aocsa commented on a change in pull request #11019:
URL: https://github.com/apache/arrow/pull/11019#discussion_r703116702
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1799,736 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+using SelectKOptionsState = internal::OptionsWrapper<SelectKOptions>;
+const auto kDefaultTopKOptions = SelectKOptions::TopKDefault();
+const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault();
+
+const FunctionDoc top_k_doc(
+ "Returns the first k elements ordered by `options.keys` in ascending
order",
+ ("This function computes the k largest elements in ascending order of the
input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+const FunctionDoc bottom_k_doc(
+ "Returns the first k elements ordered by `options.keys` in descending
order",
+ ("This function computes the k smallest elements in descending order of
the input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForFixedSizedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
+
+template <SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval);
+};
+
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return lval < rval;
+ }
+};
+
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return rval < lval;
+ }
+};
+
+template <SortOrder sort_order>
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ options_(options),
+ physical_type_(GetPhysicalType(array.type())),
+ output_(output) {}
+
+ Status Run() { return VisitTypeInline(*physical_type_, this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+ Status Visit(const DataType& type) {
+ return Status::TypeError("Unsupported type for ArraySelecter: ",
type.ToString());
+ }
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ ArrayType arr(array_.data());
+ std::vector<uint64_t> indices(arr.length());
+
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+ if (options_.k > arr.length()) {
+ options_.k = arr.length();
+ }
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ std::function<bool(uint64_t, uint64_t)> cmp;
+ SelectKComparator<sort_order> comparator;
+ cmp = [&arr, &comparator](uint64_t left, uint64_t right) -> bool {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto lval = GetView::LogicalValue(arr.GetView(x_index));
+ const auto rval = GetView::LogicalValue(arr.GetView(heap.top()));
+ if (comparator(lval, rval)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ if (options_.keep_duplicates == true) {
+ iter = indices_begin;
+ for (; iter != end_iter; ++iter) {
+ if (*iter != heap.top()) {
+ const auto lval = GetView::LogicalValue(arr.GetView(*iter));
+ const auto rval = GetView::LogicalValue(arr.GetView(heap.top()));
+ if (lval == rval) {
+ heap.Push(*iter);
+ }
+ }
+ }
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForFixedSizedType(uint64(), out_size,
ctx_->memory_pool()));
+
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ ARROW_ASSIGN_OR_RAISE(*output_, Take(array_,
Datum(std::move(take_indices)),
+ TakeOptions::NoBoundsCheck(), ctx_));
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Array& array_;
+ SelectKOptions options_;
+ const std::shared_ptr<DataType> physical_type_;
+ Datum* output_;
+};
+
/// A candidate element held in the chunked select-k heap. It is identified by its
/// position inside one chunk plus the logical start of that chunk, so that
/// `index + offset` is the element's position in the whole chunked array.
template <typename ChunkArrayT>
struct TypedHeapItem {
  uint64_t index;      // position of the element inside `*array`
  uint64_t offset;     // logical start of `*array` within the chunked array
  ChunkArrayT* array;  // chunk holding the element (non-owning)
};
+
+template <SortOrder sort_order>
+class ChunkedArraySelecter : public TypeVisitor {
+ public:
+ ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ chunked_array_(chunked_array),
+ physical_type_(GetPhysicalType(chunked_array.type())),
+ physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+ options_(options),
+ ctx_(ctx),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ using HeapItem = TypedHeapItem<ArrayType>;
+
+ const auto num_chunks = chunked_array_.num_chunks();
+ if (num_chunks == 0) {
+ return Status::OK();
+ }
+ if (options_.k > chunked_array_.length()) {
+ options_.k = chunked_array_.length();
+ }
+ std::function<bool(const HeapItem&, const HeapItem&)> cmp;
+ SelectKComparator<sort_order> comparator;
+
+ cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool {
+ const auto lval = GetView::LogicalValue(left.array->GetView(left.index));
+ const auto rval =
GetView::LogicalValue(right.array->GetView(right.index));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<HeapItem, decltype(cmp)> heap(cmp);
+ std::vector<std::shared_ptr<ArrayType>> chunks_holder;
+ uint64_t offset = 0;
+ for (const auto& chunk : physical_chunks_) {
+ if (chunk->length() == 0) continue;
+ chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data()));
+ ArrayType& arr = *chunks_holder[chunks_holder.size() - 1];
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin && heap.size() <
static_cast<size_t>(options_.k); ++iter) {
+ heap.Push(HeapItem{*iter, offset, &arr});
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto& xval = GetView::LogicalValue(arr.GetView(x_index));
+ auto top_item = heap.top();
+ const auto& top_value =
+ GetView::LogicalValue(top_item.array->GetView(top_item.index));
+ if (comparator(xval, top_value)) {
+ heap.ReplaceTop(HeapItem{x_index, offset, &arr});
+ }
+ }
+ offset += chunk->length();
+ }
+
+ if (options_.keep_duplicates == true) {
+ offset = 0;
+ for (const auto& chunk : chunks_holder) {
+ ArrayType& arr = *chunk;
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto iter = indices_begin;
+ for (; iter != indices_end; ++iter) {
+ uint64_t x_index = *iter;
+ auto top_item = heap.top();
+ if (x_index + offset != top_item.index + top_item.offset) {
+ const auto& xval = GetView::LogicalValue(arr.GetView(x_index));
+ const auto& top_value =
+ GetView::LogicalValue(top_item.array->GetView(top_item.index));
+ if (xval == top_value) {
+ heap.Push(HeapItem{x_index, offset, &arr});
+ }
+ }
+ }
+ offset += chunk->length();
+ }
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForFixedSizedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ auto top_item = heap.top();
+ *out_cbegin = top_item.index + top_item.offset;
+ heap.Pop();
+ --out_cbegin;
+ }
+ ARROW_ASSIGN_OR_RAISE(auto chunked_select_k,
+ Take(Datum(chunked_array_),
Datum(std::move(take_indices)),
+ TakeOptions::NoBoundsCheck(), ctx_));
+ ARROW_ASSIGN_OR_RAISE(
+ auto select_k,
+ Concatenate(chunked_select_k.chunked_array()->chunks(),
ctx_->memory_pool()));
Review comment:
Update: I changed these APIs to return only indices like SortIndices.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]