wesm commented on a change in pull request #7442:
URL: https://github.com/apache/arrow/pull/7442#discussion_r441674808



##########
File path: cpp/src/arrow/compute/kernels/vector_selection.cc
##########
@@ -0,0 +1,1816 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/extension_type.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/int_util.h"
+
+namespace arrow {
+
+using internal::BinaryBitBlockCounter;
+using internal::BitBlockCount;
+using internal::BitBlockCounter;
+using internal::BitmapReader;
+using internal::CopyBitmap;
+using internal::GetArrayView;
+using internal::IndexBoundsCheck;
+using internal::OptionalBitBlockCounter;
+using internal::OptionalBitIndexer;
+
+namespace compute {
+namespace internal {
+// Count the slots a boolean filter selects (null slots included under EMIT_NULL).
+int64_t GetFilterOutputSize(const ArrayData& filter,
+                            FilterOptions::NullSelectionBehavior 
null_selection) {
+  int64_t output_size = 0;
+  int64_t position = 0;  // number of filter slots scanned so far
+  if (filter.GetNullCount() > 0) {
+    const uint8_t* filter_is_valid = filter.buffers[0]->data();
+    BinaryBitBlockCounter bit_counter(filter.buffers[1]->data(), filter.offset,
+                                      filter_is_valid, filter.offset, 
filter.length);
+    if (null_selection == FilterOptions::EMIT_NULL) {
+      while (position < filter.length) {
+        BitBlockCount block = bit_counter.NextOrNotWord();  // data OR NOT valid
+        output_size += block.popcount;
+        position += block.length;
+      }
+    } else {
+      while (position < filter.length) {
+        BitBlockCount block = bit_counter.NextAndWord();  // data AND valid
+        output_size += block.popcount;
+        position += block.length;
+      }
+    }
+  } else {
+    // The filter has no nulls, so only the data bitmap is scanned and every
+    // set bit counts as one output slot.
+    BitBlockCounter bit_counter(filter.buffers[1]->data(), filter.offset, 
filter.length);
+    while (position < filter.length) {
+      BitBlockCount block = bit_counter.NextFourWords();
+      output_size += block.popcount;
+      position += block.length;
+    }
+  }
+  return output_size;
+}
+// Build the IndexType array of filter-relative positions selected by `filter`.
+template <typename IndexType>
+Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl(
+    const ArrayData& filter, FilterOptions::NullSelectionBehavior 
null_selection,
+    MemoryPool* memory_pool) {
+  using T = typename IndexType::c_type;
+  typename TypeTraits<IndexType>::BuilderType builder(memory_pool);
+
+  const uint8_t* filter_data = filter.buffers[1]->data();
+  BitBlockCounter data_counter(filter_data, filter.offset, filter.length);
+
+  // The position relative to the start of the filter
+  T position = 0;
+
+  // The current position taking the filter offset into account
+  int64_t position_with_offset = filter.offset;
+  if (filter.GetNullCount() > 0) {
+    // The filter has nulls, so we scan the validity bitmap and the filter data
+    // bitmap together, branching on the null selection type.
+    const uint8_t* filter_is_valid = filter.buffers[0]->data();
+
+    // Joint counter over filter_data and filter_is_valid (AND / OR-NOT blocks)
+    BinaryBitBlockCounter filter_counter(filter_data, filter.offset, 
filter_is_valid,
+                                         filter.offset, filter.length);
+    if (null_selection == FilterOptions::DROP) {
+      while (position < filter.length) {
+        BitBlockCount and_block = filter_counter.NextAndWord();
+        RETURN_NOT_OK(builder.Reserve(and_block.popcount));  // one slot per append below
+        if (and_block.AllSet()) {
+          // All the values are selected and non-null
+          for (int64_t i = 0; i < and_block.length; ++i) {
+            builder.UnsafeAppend(position++);
+          }
+          position_with_offset += and_block.length;
+        } else {
+          // Some of the values are false or null
+          for (int64_t i = 0; i < and_block.length; ++i) {
+            if (BitUtil::GetBit(filter_is_valid, position_with_offset) &&
+                BitUtil::GetBit(filter_data, position_with_offset)) {
+              builder.UnsafeAppend(position);
+            }
+            ++position;
+            ++position_with_offset;
+          }
+        }
+      }
+    } else {  // not DROP: null filter slots are emitted as null indices
+      BitBlockCounter is_valid_counter(filter_is_valid, filter.offset, 
filter.length);
+      while (position < filter.length) {
+        // true OR NOT valid
+        BitBlockCount or_not_block = filter_counter.NextOrNotWord();
+        RETURN_NOT_OK(builder.Reserve(or_not_block.popcount));
+
+        // If the values are all valid and the or_not_block is full, then we
+        // can infer that all the values are true and skip the bit checking
+        BitBlockCount is_valid_block = is_valid_counter.NextWord();
+
+        if (or_not_block.AllSet() && is_valid_block.AllSet()) {
+          // All the values are selected and non-null
+          for (int64_t i = 0; i < or_not_block.length; ++i) {
+            builder.UnsafeAppend(position++);
+          }
+          position_with_offset += or_not_block.length;
+        } else {
+          // Some of the values are false or null
+          for (int64_t i = 0; i < or_not_block.length; ++i) {
+            if (BitUtil::GetBit(filter_is_valid, position_with_offset)) {
+              if (BitUtil::GetBit(filter_data, position_with_offset)) {
+                builder.UnsafeAppend(position);
+              }
+            } else {
+              // Null slot, so append a null
+              builder.UnsafeAppendNull();
+            }
+            ++position;
+            ++position_with_offset;
+          }
+        }
+      }
+    }
+  } else {
+    // The filter has no nulls, so we need only look for true values
+    BitBlockCount current_block = data_counter.NextWord();
+    while (position < filter.length) {
+      if (current_block.AllSet()) {
+        int64_t run_length = 0;
+
+        // If we've found an all-true block, then we scan forward until we find
+        // a block that has some false values (or we reach the end)
+        while (current_block.length > 0 && current_block.AllSet()) {
+          run_length += current_block.length;
+          current_block = data_counter.NextWord();
+        }
+
+        // Append the consecutive run of indices
+        RETURN_NOT_OK(builder.Reserve(run_length));
+        for (int64_t i = 0; i < run_length; ++i) {
+          builder.UnsafeAppend(position++);
+        }
+        position_with_offset += run_length;
+      } else {
+        // Must do bitchecking on the current block
+        RETURN_NOT_OK(builder.Reserve(current_block.popcount));
+        for (int64_t i = 0; i < current_block.length; ++i) {
+          if (BitUtil::GetBit(filter_data, position_with_offset)) {
+            builder.UnsafeAppend(position);
+          }
+          ++position;
+          ++position_with_offset;
+        }
+        current_block = data_counter.NextWord();
+      }
+    }
+  }
+  std::shared_ptr<ArrayData> result;
+  RETURN_NOT_OK(builder.FinishInternal(&result));
+  return result;
+}
+// Convert a boolean filter into take indices, using the narrowest index type that fits.
+Result<std::shared_ptr<ArrayData>> GetTakeIndices(
+    const ArrayData& filter, FilterOptions::NullSelectionBehavior 
null_selection,
+    MemoryPool* memory_pool) {
+  DCHECK_EQ(filter.type->id(), Type::BOOL);  // the filter must be a boolean array
+  if (filter.length <= std::numeric_limits<uint16_t>::max()) {
+    return GetTakeIndicesImpl<UInt16Type>(filter, null_selection, memory_pool);
+  } else if (filter.length <= std::numeric_limits<uint32_t>::max()) {
+    return GetTakeIndicesImpl<UInt32Type>(filter, null_selection, memory_pool);
+  } else {
+    // Arrays over 4 billion elements, not especially likely; 64-bit indices
+    // are not generated here, so report NotImplemented instead.
+    return Status::NotImplemented(
+        "Filter length exceeds UINT32_MAX, "
+        "consider a different strategy for selecting elements");
+  }
+}
+
+namespace {
+// Trait yielding the C type for an Arrow type: ArrowType::c_type by default
+template <typename ArrowType>
+struct GetCType {
+  using type = typename ArrowType::c_type;
+};
+
+// Specialization: for boolean we want uint8_t instead of bool
+template <>
+struct GetCType<BooleanType> {
+  using type = uint8_t;
+};
+// Kernel state types wrapping the options given to the Filter and Take kernels
+using FilterState = OptionsWrapper<FilterOptions>;
+using TakeState = OptionsWrapper<TakeOptions>;
+// Set the output array's length and allocate its validity bitmap and data buffer.
+Status PreallocateData(KernelContext* ctx, int64_t length, int bit_width, 
Datum* out) {
+  // Two buffers are allocated: [0] the validity bitmap, [1] the values
+  ArrayData* out_arr = out->mutable_array();
+  out_arr->length = length;
+  out_arr->buffers.resize(2);
+
+  ARROW_ASSIGN_OR_RAISE(out_arr->buffers[0], ctx->AllocateBitmap(length));
+  if (bit_width == 1) {  // 1-bit values are stored as a bitmap
+    ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1], ctx->AllocateBitmap(length));
+  } else {  // otherwise the data buffer is sized in bytes
+    ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1], ctx->Allocate(length * 
bit_width / 8));
+  }
+  return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Implement optimized take for primitive types from boolean to 1/2/4/8-byte
+// C-type based types. Use common implementation for every byte width and only
+// generate code for unsigned integer indices, since after boundschecking to
+// check for negative numbers in the indices we can safely reinterpret_cast
+// signed integers as unsigned.
+
+/// \brief The Take implementation for primitive (fixed-width) types does not
+/// use the logical Arrow type but rather the physical C type. This way we
+/// only generate one take function for each byte width.
+/// Exec writes the output values and validity bitmap and sets the null count.
+/// This function assumes that the indices have been boundschecked.
+template <typename IndexCType, typename ValueCType>
+struct PrimitiveTakeImpl {
+  static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices,
+                   Datum* out_datum) {
+    auto values_data = reinterpret_cast<const ValueCType*>(values.data);
+    auto values_is_valid = values.is_valid;
+    auto values_offset = values.offset;
+
+    auto indices_data = reinterpret_cast<const IndexCType*>(indices.data);
+    auto indices_is_valid = indices.is_valid;
+    auto indices_offset = indices.offset;
+
+    ArrayData* out_arr = out_datum->mutable_array();
+    auto out = out_arr->GetMutableValues<ValueCType>(1);
+    auto out_is_valid = out_arr->buffers[0]->mutable_data();
+    auto out_offset = out_arr->offset;
+
+    // If either the values or indices have nulls, we preemptively zero out the
+    // out validity bitmap so that we don't have to use ClearBit in each
+    // iteration for nulls.
+    if (values.null_count > 0 || indices.null_count > 0) {
+      BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false);
+    }
+
+    OptionalBitBlockCounter indices_bit_counter(indices_is_valid, 
indices_offset,
+                                                indices.length);
+    int64_t position = 0;     // next index / output slot to process
+    int64_t valid_count = 0;  // number of non-null output slots written
+    while (position < indices.length) {
+      BitBlockCount block = indices_bit_counter.NextBlock();
+      if (values.null_count == 0) {
+        // Values are never null, so things are easier
+        valid_count += block.popcount;
+        if (block.popcount == block.length) {
+          // Fastest path: neither values nor index nulls
+          BitUtil::SetBitsTo(out_is_valid, out_offset + position, 
block.length, true);
+          for (int64_t i = 0; i < block.length; ++i) {
+            out[position] = values_data[indices_data[position]];
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some indices but not all are null
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+              // index is not null
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              out[position] = values_data[indices_data[position]];
+            } else {
+              out[position] = ValueCType{};
+            }
+            ++position;
+          }
+        } else {  // every index in the block is null: zero the output values
+          memset(out + position, 0, sizeof(ValueCType) * block.length);
+          position += block.length;
+        }
+      } else {
+        // Values have nulls, so we must do random access into the values 
bitmap
+        if (block.popcount == block.length) {
+          // Faster path: indices are not null but values may be
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(values_is_valid,
+                                values_offset + indices_data[position])) {
+              // value is not null
+              out[position] = values_data[indices_data[position]];
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              ++valid_count;
+            } else {
+              out[position] = ValueCType{};
+            }
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some but not all indices are null. Since we are doing
+          // random access in general we have to check the value nullness one 
by
+          // one.
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(indices_is_valid, indices_offset + position) &&
+                BitUtil::GetBit(values_is_valid,
+                                values_offset + indices_data[position])) {
+              // index is not null && value is not null
+              out[position] = values_data[indices_data[position]];
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              ++valid_count;
+            } else {
+              out[position] = ValueCType{};
+            }
+            ++position;
+          }
+        } else {  // every index in the block is null: zero the output values
+          memset(out + position, 0, sizeof(ValueCType) * block.length);
+          position += block.length;
+        }
+      }
+    }
+    out_arr->null_count = out_arr->length - valid_count;
+  }
+};
+// Take implementation for bit-packed boolean values; IndexCType is the index C type.
+template <typename IndexCType>
+struct BooleanTakeImpl {
+  static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices,
+                   Datum* out_datum) {
+    const uint8_t* values_data = values.data;
+    auto values_is_valid = values.is_valid;
+    auto values_offset = values.offset;
+
+    auto indices_data = reinterpret_cast<const IndexCType*>(indices.data);
+    auto indices_is_valid = indices.is_valid;
+    auto indices_offset = indices.offset;
+
+    ArrayData* out_arr = out_datum->mutable_array();
+    auto out = out_arr->buffers[1]->mutable_data();
+    auto out_is_valid = out_arr->buffers[0]->mutable_data();
+    auto out_offset = out_arr->offset;
+
+    // If either the values or indices have nulls, we preemptively zero out the
+    // out validity bitmap so that we don't have to use ClearBit in each
+    // iteration for nulls.
+    if (values.null_count > 0 || indices.null_count > 0) {
+      BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false);
+    }
+    // Avoid uninitialized data in values array
+    BitUtil::SetBitsTo(out, out_offset, indices.length, false);
+    // Copy the value bit at `index` (relative to values_offset) to output slot `loc`
+    auto PlaceDataBit = [&](int64_t loc, IndexCType index) {
+      BitUtil::SetBitTo(out, out_offset + loc,
+                        BitUtil::GetBit(values_data, values_offset + index));
+    };
+
+    OptionalBitBlockCounter indices_bit_counter(indices_is_valid, 
indices_offset,
+                                                indices.length);
+    int64_t position = 0;     // next index / output slot to process
+    int64_t valid_count = 0;  // number of non-null output slots written
+    while (position < indices.length) {
+      BitBlockCount block = indices_bit_counter.NextBlock();
+      if (values.null_count == 0) {
+        // Values are never null, so things are easier
+        valid_count += block.popcount;
+        if (block.popcount == block.length) {
+          // Fastest path: neither values nor index nulls
+          BitUtil::SetBitsTo(out_is_valid, out_offset + position, 
block.length, true);
+          for (int64_t i = 0; i < block.length; ++i) {
+            PlaceDataBit(position, indices_data[position]);
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some but not all indices are null
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+              // index is not null
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              PlaceDataBit(position, indices_data[position]);
+            }
+            ++position;
+          }
+        } else {  // every index in the block is null; output bits were zeroed above
+          position += block.length;
+        }
+      } else {
+        // Values have nulls, so we must do random access into the values 
bitmap
+        if (block.popcount == block.length) {
+          // Faster path: indices are not null but values may be
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(values_is_valid,
+                                values_offset + indices_data[position])) {
+              // value is not null
+              BitUtil::SetBit(out_is_valid, out_offset + position);
+              PlaceDataBit(position, indices_data[position]);
+              ++valid_count;
+            }
+            ++position;
+          }
+        } else if (block.popcount > 0) {
+          // Slow path: some but not all indices are null. Since we are doing
+          // random access in general we have to check the value nullness one 
by
+          // one.
+          for (int64_t i = 0; i < block.length; ++i) {
+            if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+              // index is not null
+              if (BitUtil::GetBit(values_is_valid,
+                                  values_offset + indices_data[position])) {
+                // value is not null
+                PlaceDataBit(position, indices_data[position]);
+                BitUtil::SetBit(out_is_valid, out_offset + position);
+                ++valid_count;
+              }
+            }
+            ++position;
+          }
+        } else {  // every index in the block is null; output bits were zeroed above
+          position += block.length;
+        }
+      }
+    }
+    out_arr->null_count = out_arr->length - valid_count;
+  }
+};
+// Run TakeImpl<IndexCType, Args...>::Exec for the unsigned type matching indices.bit_width.
+template <template <typename...> class TakeImpl, typename... Args>
+void TakeIndexDispatch(const PrimitiveArg& values, const PrimitiveArg& indices,
+                       Datum* out) {
+  // With the simplifying assumption that boundschecking has taken place
+  // already at a higher level, we can now assume that the index values are all
+  // non-negative. Thus, we can interpret signed integers as unsigned and avoid
+  // having to generate double the amount of binary code to handle each integer
+  // width.
+  switch (indices.bit_width) {
+    case 8:
+      return TakeImpl<uint8_t, Args...>::Exec(values, indices, out);
+    case 16:
+      return TakeImpl<uint16_t, Args...>::Exec(values, indices, out);
+    case 32:
+      return TakeImpl<uint32_t, Args...>::Exec(values, indices, out);
+    case 64:
+      return TakeImpl<uint64_t, Args...>::Exec(values, indices, out);
+    default:  // any other index bit width is unexpected here
+      DCHECK(false) << "Invalid indices byte width";
+      break;
+  }
+}
+// Take kernel for primitive values: optional boundscheck, preallocate, dispatch on width.
+void PrimitiveTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& state = checked_cast<const TakeState&>(*ctx->state());
+  if (state.options.boundscheck) {  // batch[0] holds the values, batch[1] the indices
+    KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), 
batch[0].length()));
+  }
+
+  PrimitiveArg values = GetPrimitiveArg(*batch[0].array());
+  PrimitiveArg indices = GetPrimitiveArg(*batch[1].array());
+  KERNEL_RETURN_IF_ERROR(ctx,
+                         PreallocateData(ctx, indices.length, 
values.bit_width, out));
+  switch (values.bit_width) {
+    case 1:  // bit-packed values use the boolean implementation
+      return TakeIndexDispatch<BooleanTakeImpl>(values, indices, out);
+    case 8:
+      return TakeIndexDispatch<PrimitiveTakeImpl, int8_t>(values, indices, 
out);
+    case 16:
+      return TakeIndexDispatch<PrimitiveTakeImpl, int16_t>(values, indices, 
out);
+    case 32:
+      return TakeIndexDispatch<PrimitiveTakeImpl, int32_t>(values, indices, 
out);
+    case 64:
+      return TakeIndexDispatch<PrimitiveTakeImpl, int64_t>(values, indices, 
out);
+    default:  // any other value bit width is unexpected here
+      DCHECK(false) << "Invalid values byte width";
+      break;
+  }
+}
+
+// ----------------------------------------------------------------------
+// Optimized and streamlined filter for primitive types
+
+// Use either BitBlockCounter or BinaryBitBlockCounter to quickly scan the
+// filter a word at a time for the DROP selection type.
+class DropNullCounter {
+ public:
+  // validity bitmap may be null
+  DropNullCounter(const uint8_t* validity, const uint8_t* data, int64_t offset,
+                  int64_t length)
+      : data_counter_(data, offset, length),
+        data_and_validity_counter_(data, offset, validity, offset, length),
+        has_validity_(validity != nullptr) {}
+
+  BitBlockCount NextBlock() {  // next word-sized block from whichever counter applies
+    if (has_validity_) {
+      // filter is true AND not null
+      return data_and_validity_counter_.NextAndWord();
+    } else {
+      return data_counter_.NextWord();
+    }
+  }
+
+ private:
+  // For when just data is present, but no validity bitmap
+  BitBlockCounter data_counter_;
+
+  // For when both validity bitmap and data are present
+  BinaryBitBlockCounter data_and_validity_counter_;
+  bool has_validity_;  // true when a validity bitmap was supplied
+};
+
+/// \brief The Filter implementation for primitive (fixed-width) types does not
+/// use the logical Arrow type but rather the physical C type. This way we
+/// only generate one filter function for each byte width. We use the same
+/// implementation here for boolean and fixed-byte-size inputs with some
+/// template specialization.
+template <typename ArrowType>
+class PrimitiveFilterImpl {
+ public:
+  using T = typename GetCType<ArrowType>::type;
+
+  PrimitiveFilterImpl(const PrimitiveArg& values, const PrimitiveArg& filter,
+                      FilterOptions::NullSelectionBehavior null_selection,
+                      Datum* out_datum)
+      : values_is_valid_(values.is_valid),
+        values_data_(reinterpret_cast<const T*>(values.data)),
+        values_null_count_(values.null_count),
+        values_offset_(values.offset),
+        values_length_(values.length),
+        filter_is_valid_(filter.is_valid),
+        filter_data_(filter.data),
+        filter_null_count_(filter.null_count),
+        filter_offset_(filter.offset),
+        null_selection_(null_selection) {
+    ArrayData* out_arr = out_datum->mutable_array();  // buffers must already be allocated
+    out_is_valid_ = out_arr->buffers[0]->mutable_data();
+    out_data_ = reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data());
+    out_offset_ = out_arr->offset;
+    out_length_ = out_arr->length;
+    out_position_ = 0;
+  }
+
+  void ExecNonNull() {  // path taken by Exec() when neither input has nulls
+    // The result is all not-null
+    BitUtil::SetBitsTo(out_is_valid_, out_offset_ + out_position_, 
out_length_, true);
+
+    // Fast filter when values and filter are not null
+    // Bit counters used for both null_selection behaviors
+    BitBlockCounter filter_counter(filter_data_, filter_offset_, 
values_length_);
+
+    int64_t in_position = 0;
+    BitBlockCount current_block = filter_counter.NextWord();
+    while (in_position < values_length_) {
+      if (current_block.AllSet()) {
+        int64_t run_length = 0;
+        // If we've found an all-true block, then we scan forward until we find
+        // a block that has some false values (or we reach the end)
+        while (current_block.length > 0 && current_block.AllSet()) {
+          run_length += current_block.length;
+          current_block = filter_counter.NextWord();
+        }
+        WriteValueSegment(in_position, run_length);
+        in_position += run_length;
+      } else if (current_block.NoneSet()) {
+        // Nothing selected
+        in_position += current_block.length;
+        current_block = filter_counter.NextWord();
+      } else {
+        // Some values selected
+        for (int64_t i = 0; i < current_block.length; ++i) {
+          if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+            WriteValue(in_position);
+          }
+          ++in_position;
+        }
+        current_block = filter_counter.NextWord();
+      }
+    }
+  }
+
+  void Exec() {  // filters the values, handling nulls in values and/or filter
+    if (filter_null_count_ == 0 && values_null_count_ == 0) {
+      return ExecNonNull();
+    }
+
+    // Bit counters used for both null_selection behaviors
+    DropNullCounter drop_null_counter(filter_is_valid_, filter_data_, 
filter_offset_,
+                                      values_length_);
+    OptionalBitBlockCounter data_counter(values_is_valid_, values_offset_,
+                                         values_length_);
+    OptionalBitBlockCounter filter_valid_counter(filter_is_valid_, 
filter_offset_,
+                                                 values_length_);
+
+    auto WriteNotNull = [&](int64_t index) {
+      BitUtil::SetBit(out_is_valid_, out_offset_ + out_position_);
+      // Increments out_position_
+      WriteValue(index);
+    };
+
+    auto WriteMaybeNull = [&](int64_t index) {
+      BitUtil::SetBitTo(out_is_valid_, out_offset_ + out_position_,
+                        BitUtil::GetBit(values_is_valid_, values_offset_ + 
index));
+      // Increments out_position_
+      WriteValue(index);
+    };
+
+    int64_t in_position = 0;
+    while (in_position < values_length_) {
+      BitBlockCount filter_block = drop_null_counter.NextBlock();
+      BitBlockCount filter_valid_block = filter_valid_counter.NextWord();
+      BitBlockCount data_block = data_counter.NextWord();
+      if (filter_block.AllSet() && data_block.AllSet()) {
+        // Fastest path: all values in block are included and not null
+        BitUtil::SetBitsTo(out_is_valid_, out_offset_ + out_position_,
+                           filter_block.length, true);
+        WriteValueSegment(in_position, filter_block.length);
+        in_position += filter_block.length;
+      } else if (filter_block.AllSet()) {
+        // Faster: all values are selected, but some values are null
+        // Batch copy bits from values validity bitmap to output validity 
bitmap
+        CopyBitmap(values_is_valid_, values_offset_ + in_position, 
filter_block.length,
+                   out_is_valid_, out_offset_ + out_position_);
+        WriteValueSegment(in_position, filter_block.length);
+        in_position += filter_block.length;
+      } else if (filter_block.NoneSet() && null_selection_ == 
FilterOptions::DROP) {
+        // For this exceedingly common case in low-selectivity filters we can
+        // skip further analysis of the data and move on to the next block.
+        in_position += filter_block.length;
+      } else {
+        // Some filter values are false or null
+        if (data_block.AllSet()) {
+          // No values are null
+          if (filter_valid_block.AllSet()) {
+            // Filter is non-null but some values are false
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) 
{
+                WriteNotNull(in_position);
+              }
+              ++in_position;
+            }
+          } else if (null_selection_ == FilterOptions::DROP) {
+            // If any values are selected, they ARE NOT null
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + 
in_position) &&
+                  BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) 
{
+                WriteNotNull(in_position);
+              }
+              ++in_position;
+            }
+          } else {  // null_selection == FilterOptions::EMIT_NULL
+            // Data values in this block are not null
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              const bool is_valid =
+                  BitUtil::GetBit(filter_is_valid_, filter_offset_ + 
in_position);
+              if (is_valid &&
+                  BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) 
{
+                // Filter slot is non-null and set
+                WriteNotNull(in_position);
+              } else if (!is_valid) {
+                // Filter slot is null, so we have a null in the output
+                BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_);
+                WriteNull();
+              }
+              ++in_position;
+            }
+          }
+        } else {  // !data_block.AllSet()
+          // Some values are null
+          if (filter_valid_block.AllSet()) {
+            // Filter is non-null but some values are false
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) 
{
+                WriteMaybeNull(in_position);
+              }
+              ++in_position;
+            }
+          } else if (null_selection_ == FilterOptions::DROP) {
+            // Selected values may be null, so their validity bit is copied
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + 
in_position) &&
+                  BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) 
{
+                WriteMaybeNull(in_position);
+              }
+              ++in_position;
+            }
+          } else {  // null_selection == FilterOptions::EMIT_NULL
+            // Some data values in this block are null
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              const bool is_valid =
+                  BitUtil::GetBit(filter_is_valid_, filter_offset_ + 
in_position);
+              if (is_valid &&
+                  BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) 
{
+                // Filter slot is non-null and set
+                WriteMaybeNull(in_position);
+              } else if (!is_valid) {
+                // Filter slot is null, so we have a null in the output
+                BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_);
+                WriteNull();
+              }
+              ++in_position;
+            }
+          }
+        }
+      }  // !filter_block.AllSet()
+    }    // while(in_position < values_length_)
+  }
+
+  // Write the next out_position given the selected in_position for the input
+  // data and advance out_position
+  void WriteValue(int64_t in_position) {
+    out_data_[out_position_++] = values_data_[in_position];
+  }
+
+  void WriteValueSegment(int64_t in_start, int64_t length) {  // bulk copy of a run
+    std::memcpy(out_data_ + out_position_, values_data_ + in_start, length * 
sizeof(T));
+    out_position_ += length;
+  }
+
+  void WriteNull() {  // writes a zero value and advances out_position_
+    // Zero the memory
+    out_data_[out_position_++] = T{};
+  }
+
+ private:
+  const uint8_t* values_is_valid_;
+  const T* values_data_;
+  int64_t values_null_count_;
+  int64_t values_offset_;
+  int64_t values_length_;
+  const uint8_t* filter_is_valid_;
+  const uint8_t* filter_data_;
+  int64_t filter_null_count_;
+  int64_t filter_offset_;
+  FilterOptions::NullSelectionBehavior null_selection_;
+  uint8_t* out_is_valid_;
+  T* out_data_;
+  int64_t out_offset_;
+  int64_t out_length_;
+  int64_t out_position_;  // next output slot to be written
+};
+
+template <>
+inline void PrimitiveFilterImpl<BooleanType>::WriteValue(int64_t in_position) {
+  // Read the selected bit from the input (honoring the values offset) and
+  // store it at the next output bit (honoring the output offset).
+  const bool selected_bit = BitUtil::GetBit(values_data_, values_offset_ + in_position);
+  BitUtil::SetBitTo(out_data_, out_offset_ + out_position_, selected_bit);
+  ++out_position_;
+}
+
+template <>
+inline void PrimitiveFilterImpl<BooleanType>::WriteValueSegment(int64_t in_start,
+                                                                int64_t length) {
+  // Copy a run of `length` bits from the input bitmap to the output bitmap
+  const int64_t src_bit_offset = values_offset_ + in_start;
+  const int64_t dest_bit_offset = out_offset_ + out_position_;
+  CopyBitmap(values_data_, src_bit_offset, length, out_data_, dest_bit_offset);
+  out_position_ += length;
+}
+
+template <>
+inline void PrimitiveFilterImpl<BooleanType>::WriteNull() {
+  // Clear the data bit for the null slot, then move to the next output bit
+  BitUtil::ClearBit(out_data_, out_offset_ + out_position_);
+  ++out_position_;
+}
+
+void PrimitiveFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // batch[0] holds the values to filter, batch[1] the filter itself
+  const auto& filter_state = checked_cast<const FilterState&>(*ctx->state());
+  PrimitiveArg values = GetPrimitiveArg(*batch[0].array());
+  PrimitiveArg filter = GetPrimitiveArg(*batch[1].array());
+  FilterOptions::NullSelectionBehavior null_selection =
+      filter_state.options.null_selection_behavior;
+
+  // Size the output up front from the filter contents
+  int64_t output_length = GetFilterOutputSize(*batch[1].array(), null_selection);
+  KERNEL_RETURN_IF_ERROR(ctx, PreallocateData(ctx, output_length, values.bit_width, out));
+
+  // The output null count can only be precomputed (as zero) when the values
+  // contain no nulls and the filter cannot introduce any: either filter nulls
+  // are dropped or the filter has none. Otherwise it is left unknown.
+  const bool output_has_no_nulls =
+      values.null_count == 0 &&
+      (null_selection == FilterOptions::DROP || filter.null_count == 0);
+  out->mutable_array()->null_count = output_has_no_nulls ? 0 : kUnknownNullCount;
+
+  // Dispatch on the physical bit width of the values
+  switch (values.bit_width) {
+    case 1:
+      PrimitiveFilterImpl<BooleanType>(values, filter, null_selection, out).Exec();
+      return;
+    case 8:
+      PrimitiveFilterImpl<UInt8Type>(values, filter, null_selection, out).Exec();
+      return;
+    case 16:
+      PrimitiveFilterImpl<UInt16Type>(values, filter, null_selection, out).Exec();
+      return;
+    case 32:
+      PrimitiveFilterImpl<UInt32Type>(values, filter, null_selection, out).Exec();
+      return;
+    case 64:
+      PrimitiveFilterImpl<UInt64Type>(values, filter, null_selection, out).Exec();
+      return;
+    default:
+      DCHECK(false) << "Invalid values bit width";
+      break;
+  }
+}
+
+// ----------------------------------------------------------------------
+// Null take and filter
+
+// Take kernel for null-typed values: every output slot is null, so only the
+// output length has to be determined. batch[0] holds the values and batch[1]
+// the take indices. When bounds checking is enabled, the indices are
+// validated against the length of the values and any error is reported
+// through ctx.
+void NullTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& state = checked_cast<const TakeState&>(*ctx->state());
+  if (state.options.boundscheck) {
+    KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length()));
+  }
+  // A take produces exactly one output slot per index. batch.length is not
+  // tied to the indices argument, so size the result from the indices array.
+  const int64_t output_length = batch[1].array()->length;
+  out->value = std::make_shared<NullArray>(output_length)->data();
+}
+
+void NullFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // The result is a NullArray; only its length depends on the filter
+  // (batch[1]) and the configured null selection behavior.
+  const auto& filter_state = checked_cast<const FilterState&>(*ctx->state());
+  const auto null_selection = filter_state.options.null_selection_behavior;
+  const int64_t num_selected = GetFilterOutputSize(*batch[1].array(), null_selection);
+  auto filtered = std::make_shared<NullArray>(num_selected);
+  out->value = filtered->data();
+}
+
+// ----------------------------------------------------------------------
+// Dictionary take and filter
+
+void DictionaryTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // Run Take on the dictionary indices only, then wrap the taken indices
+  // together with the original type and dictionary.
+  const auto& take_state = checked_cast<const TakeState&>(*ctx->state());
+  DictionaryArray dict_values(batch[0].array());
+  Datum dict_indices(dict_values.indices());
+  Datum taken_indices;
+  KERNEL_RETURN_IF_ERROR(
+      ctx, Take(dict_indices, batch[1], take_state.options, ctx->exec_context())
+               .Value(&taken_indices));
+  DictionaryArray taken(dict_values.type(), taken_indices.make_array(),
+                        dict_values.dictionary());
+  out->value = taken.data();
+}
+
+void DictionaryFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // Filter the dictionary indices only, then wrap the surviving indices
+  // together with the original type and dictionary.
+  const auto& filter_state = checked_cast<const FilterState&>(*ctx->state());
+  DictionaryArray source(batch[0].array());
+  Datum source_indices(source.indices());
+  Datum kept_indices;
+  KERNEL_RETURN_IF_ERROR(ctx, Filter(source_indices, batch[1].array(),
+                                     filter_state.options, ctx->exec_context())
+                                  .Value(&kept_indices));
+  DictionaryArray filtered(source.type(), kept_indices.make_array(), source.dictionary());
+  out->value = filtered.data();
+}
+
+// ----------------------------------------------------------------------
+// Extension take and filter
+
+void ExtensionTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // Run Take on the storage array, then rewrap the result in the extension type
+  const auto& take_state = checked_cast<const TakeState&>(*ctx->state());
+  ExtensionArray ext_values(batch[0].array());
+  Datum storage(ext_values.storage());
+  Datum taken_storage;
+  KERNEL_RETURN_IF_ERROR(
+      ctx, Take(storage, batch[1], take_state.options, ctx->exec_context())
+               .Value(&taken_storage));
+  ExtensionArray taken(ext_values.type(), taken_storage.make_array());
+  out->value = taken.data();
+}
+
+void ExtensionFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // Filter the storage array, then rewrap the result in the extension type
+  const auto& filter_state = checked_cast<const FilterState&>(*ctx->state());
+  ExtensionArray source(batch[0].array());
+  Datum storage(source.storage());
+  Datum kept_storage;
+  KERNEL_RETURN_IF_ERROR(ctx, Filter(storage, batch[1].array(), filter_state.options,
+                                     ctx->exec_context())
+                                  .Value(&kept_storage));
+  ExtensionArray filtered(source.type(), kept_storage.make_array());
+  out->value = filtered.data();
+}
+
+// ----------------------------------------------------------------------
+// Implement take for other data types where there is less performance
+// sensitivity by visiting the selected indices.
+
+// Use CRTP to dispatch to type-specific processing of take indices for each
+// unsigned integer type.
+template <typename Impl, typename Type>
+struct Selection {
+  using ValuesArrayType = typename TypeTraits<Type>::ArrayType;
+
+  // Forwards the generic value visitors to the take index visitor template
+  template <typename IndexCType>
+  struct TakeAdapter {
+    static constexpr bool is_take = true;
+
+    Impl* impl;
+    explicit TakeAdapter(Impl* impl_ptr) : impl(impl_ptr) {}
+
+    // Hand both visitors over to Impl::VisitTake instantiated for IndexCType
+    template <typename OnValid, typename OnNull>
+    Status Generate(OnValid&& on_valid, OnNull&& on_null) {
+      return impl->template VisitTake<IndexCType>(std::forward<OnValid>(on_valid),
+                                                  std::forward<OnNull>(on_null));
+    }
+  };
+
+  // Forwards the generic value visitors to the VisitFilter template
+  struct FilterAdapter {
+    static constexpr bool is_take = false;
+
+    Impl* impl;
+    explicit FilterAdapter(Impl* impl_ptr) : impl(impl_ptr) {}
+
+    // Hand both visitors over to Impl::VisitFilter
+    template <typename OnValid, typename OnNull>
+    Status Generate(OnValid&& on_valid, OnNull&& on_null) {
+      return impl->VisitFilter(std::forward<OnValid>(on_valid),
+                               std::forward<OnNull>(on_null));
+    }
+  };
+
+  KernelContext* ctx;
+  std::shared_ptr<ArrayData> values;
+  std::shared_ptr<ArrayData> selection;
+  int64_t output_length;
+  ArrayData* out;
+  TypedBufferBuilder<bool> validity_builder;
+
+  Selection(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, 
Datum* out)
+      : ctx(ctx),
+        values(batch[0].array()),
+        selection(batch[1].array()),
+        output_length(output_length),
+        out(out->mutable_array()),
+        validity_builder(ctx->memory_pool()) {}
+
+  virtual ~Selection() = default;
+
+  Status FinishCommon() {
+    // Give the output as many buffer slots as the values have; slot 0 receives
+    // the validity bitmap accumulated in validity_builder.
+    const auto num_buffers = values->buffers.size();
+    out->buffers.resize(num_buffers);
+    out->null_count = validity_builder.false_count();
+    out->length = validity_builder.length();
+    return validity_builder.Finish(&out->buffers[0]);
+  }
+
+  // Walks the take indices block by block. For every index it appends one
+  // validity bit and calls visit_valid(index) when both the index and the
+  // referenced value are non-null, or visit_null() otherwise. Stops at the
+  // first non-OK Status returned by a visitor.
+  template <typename IndexCType, typename ValidVisitor, typename NullVisitor>
+  Status VisitTake(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+    const auto raw_indices = selection->GetValues<IndexCType>(1);
+    const uint8_t* indices_bitmap = GetValidityBitmap(*selection);
+    OptionalBitIndexer index_is_valid(selection->buffers[0], selection->offset);
+    OptionalBitIndexer value_is_valid(values->buffers[0], values->offset);
+    const bool values_have_nulls = (values->GetNullCount() > 0);
+
+    OptionalBitBlockCounter block_counter(indices_bitmap, selection->offset,
+                                          selection->length);
+    int64_t pos = 0;
+    while (pos < selection->length) {
+      BitBlockCount block = block_counter.NextBlock();
+      const int64_t block_end = pos + block.length;
+      const bool some_index_null = block.popcount < block.length;
+
+      if (!some_index_null && !values_have_nulls) {
+        // Fastest path: no nulls among the indices of this block nor in the
+        // values, so the whole block is valid.
+        validity_builder.UnsafeAppend(block.length, true);
+        for (; pos < block_end; ++pos) {
+          RETURN_NOT_OK(visit_valid(raw_indices[pos]));
+        }
+        continue;
+      }
+
+      if (block.popcount > 0) {
+        // A per-slot branch is needed anyway, so the "all indices valid but
+        // some values null" and "some indices null" cases share one loop.
+        for (; pos < block_end; ++pos) {
+          const bool slot_has_index = !some_index_null || index_is_valid[pos];
+          if (slot_has_index && value_is_valid[raw_indices[pos]]) {
+            validity_builder.UnsafeAppend(true);
+            RETURN_NOT_OK(visit_valid(raw_indices[pos]));
+          } else {
+            validity_builder.UnsafeAppend(false);
+            RETURN_NOT_OK(visit_null());
+          }
+        }
+        continue;
+      }
+
+      // Every index in the block is null
+      validity_builder.UnsafeAppend(block.length, false);
+      for (int64_t i = 0; i < block.length; ++i) {
+        RETURN_NOT_OK(visit_null());
+      }
+      pos = block_end;
+    }
+    return Status::OK();
+  }
+
+  // We use the NullVisitor both for "selected" nulls as well as "emitted"
+  // nulls coming from the filter when using FilterOptions::EMIT_NULL
+  template <typename ValidVisitor, typename NullVisitor>
+  Status VisitFilter(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+    const auto& state = checked_cast<const FilterState&>(*ctx->state());
+    auto null_selection = state.options.null_selection_behavior;
+
+    const auto filter_data = selection->buffers[1]->data();
+
+    const uint8_t* filter_is_valid = GetValidityBitmap(*selection);
+    const int64_t filter_offset = selection->offset;
+    OptionalBitIndexer values_is_valid(values->buffers[0], values->offset);
+
+    // We use 3 block counters for fast scanning of the filter
+    //
+    // * values_valid_counter: for values null/not-null
+    // * filter_valid_counter: for filter null/not-null
+    // * filter_counter: for filter true/false
+    OptionalBitBlockCounter values_valid_counter(GetValidityBitmap(*values),
+                                                 values->offset, 
values->length);
+    OptionalBitBlockCounter filter_valid_counter(filter_is_valid, 
filter_offset,
+                                                 selection->length);
+    BitBlockCounter filter_counter(filter_data, filter_offset, 
selection->length);
+    int64_t in_position = 0;
+
+    auto AppendNotNull = [&](int64_t index) -> Status {
+      validity_builder.UnsafeAppend(true);
+      return visit_valid(index);
+    };
+
+    auto AppendNull = [&]() -> Status {
+      validity_builder.UnsafeAppend(false);
+      return visit_null();
+    };
+
+    auto AppendMaybeNull = [&](int64_t index) -> Status {
+      if (values_is_valid[index]) {
+        return AppendNotNull(index);
+      } else {
+        return AppendNull();
+      }
+    };
+
+    while (in_position < selection->length) {
+      BitBlockCount filter_valid_block = filter_valid_counter.NextWord();
+      BitBlockCount values_valid_block = values_valid_counter.NextWord();
+      BitBlockCount filter_block = filter_counter.NextWord();
+      if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) {
+        // For this exceedingly common case in low-selectivity filters we can
+        // skip further analysis of the data and move on to the next block.
+        in_position += filter_block.length;
+      } else if (filter_valid_block.AllSet()) {
+        // Simpler path: no filter values are null
+        if (filter_block.AllSet()) {
+          // Fastest path: filter values are all true and not null
+          if (values_valid_block.AllSet()) {
+            // The values aren't null either
+            validity_builder.UnsafeAppend(filter_block.length, true);
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              RETURN_NOT_OK(visit_valid(in_position++));
+            }
+          } else {
+            // Some of the values in this block are null
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              RETURN_NOT_OK(AppendMaybeNull(in_position++));
+            }
+          }
+        } else {  // !filter_block.AllSet()
+          // Some of the filter values are false, but all not null
+          if (values_valid_block.AllSet()) {
+            // All the values are not-null, so we can skip null checking for
+            // them
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+                RETURN_NOT_OK(AppendNotNull(in_position));
+              }
+              ++in_position;
+            }
+          } else {
+            // Some of the values in the block are null, so we have to check
+            // each one
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+                RETURN_NOT_OK(AppendMaybeNull(in_position));
+              }
+              ++in_position;
+            }
+          }
+        }
+      } else {  // !filter_valid_block.AllSet()
+        // Some of the filter values are null, so we have to handle the DROP
+        // versus EMIT_NULL null selection behavior.
+        if (null_selection == FilterOptions::DROP) {

Review comment:
       I'm extracting some code into lambdas and doing this transform




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to