lidavidm commented on a change in pull request #11019:
URL: https://github.com/apache/arrow/pull/11019#discussion_r704386815
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -120,6 +120,75 @@ class ARROW_EXPORT SortOptions : public FunctionOptions {
std::vector<SortKey> sort_keys;
};
+// \brief Selection algorithm. Default is `NonStableSelect` which uses a Heap
based
+// algorithm.
+enum class SelectKAlgorithm { NonStableSelect, StableSelect };
Review comment:
Hmm. Why not just have different kernels that share the same options
class?
##########
File path: python/pyarrow/includes/libarrow.pxd
##########
@@ -2031,6 +2031,33 @@ cdef extern from "arrow/compute/api.h" namespace
"arrow::compute" nogil:
CSortOptions(vector[CSortKey] sort_keys)
vector[CSortKey] sort_keys
+ cdef enum CSelectKAlgorithm" arrow::compute::SelectKAlgorithm":
+ CSelectKAlgorithm_NonStableSelect \
+ "arrow::compute::SelectKAlgorithm::NonStableSelect"
+ CSelectKAlgorithm_StableSelect \
+ "arrow::compute::SelectKAlgorithm::StableSelect"
Review comment:
[MSVC can't handle
this](https://ci.appveyor.com/project/ApacheSoftwareFoundation/arrow/builds/40684234/job/54jfqvifygawaef3#L1935);
I think this needs to be `ctypedef enum` rather than `cdef enum`, following the
example of the `SortOrder` declaration.
##########
File path: cpp/src/arrow/compute/api_vector.cc
##########
@@ -140,6 +150,50 @@ PartitionNthOptions::PartitionNthOptions(int64_t pivot)
: FunctionOptions(internal::kPartitionNthOptionsType), pivot(pivot) {}
constexpr char PartitionNthOptions::kTypeName[];
+SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys,
+ SelectKAlgorithm kind)
+ : FunctionOptions(internal::kSelectKOptionsType),
+ k(k),
+ sort_keys(std::move(sort_keys)),
+ kind(kind) {}
+
+bool SelectKOptions::is_top_k() const {
+ SortOrder order = SortOrder::Descending;
+ for (const auto& k : sort_keys) {
+ order = k.order;
+ if (order != SortOrder::Descending) {
+ break;
+ }
+ }
+ return order == SortOrder::Descending;
+}
Review comment:
Why track `order` instead of just iterating through the sort keys, returning
`false` as soon as a key's order is not `Descending`, and returning `true` after the loop?
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -120,6 +120,75 @@ class ARROW_EXPORT SortOptions : public FunctionOptions {
std::vector<SortKey> sort_keys;
};
+// \brief Selection algorithm. Default is `NonStableSelect` which uses a Heap
based
+// algorithm.
+enum class SelectKAlgorithm { NonStableSelect, StableSelect };
Review comment:
So we'd (eventually) have something like:
```cpp
SelectKStable({...}, SelectKOptions::TopK(5, {"key1"}));
SelectKUnstable({...}, SelectKOptions(5, {SortKey("key1", Ascending),
SortKey("key2", descending)}));
```
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -120,6 +120,75 @@ class ARROW_EXPORT SortOptions : public FunctionOptions {
std::vector<SortKey> sort_keys;
};
+// \brief Selection algorithm. Default is `NonStableSelect` which uses a Heap
based
+// algorithm.
+enum class SelectKAlgorithm { NonStableSelect, StableSelect };
Review comment:
Again, I'm not sure why we need so many overloads for top/bottom/select.
I think it'd be cleaner to have just SelectK and, if needed, some convenience
constructors to easily build top-k/bottom-k options.
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -120,6 +120,75 @@ class ARROW_EXPORT SortOptions : public FunctionOptions {
std::vector<SortKey> sort_keys;
};
+// \brief Selection algorithm. Default is `NonStableSelect` which uses a Heap
based
+// algorithm.
+enum class SelectKAlgorithm { NonStableSelect, StableSelect };
Review comment:
Especially because top/bottomK are implemented in terms of selectK
anyway: they're just conveniences, so let's push those conveniences to the
edges (i.e. only the options struct) instead of adding a lot of code to
support them in several places (which now requires more bindings, more code to
read in the actual implementation, more overloads, ...).
##########
File path: cpp/src/arrow/compute/api_vector.cc
##########
@@ -140,6 +150,50 @@ PartitionNthOptions::PartitionNthOptions(int64_t pivot)
: FunctionOptions(internal::kPartitionNthOptionsType), pivot(pivot) {}
constexpr char PartitionNthOptions::kTypeName[];
+SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys,
+ SelectKAlgorithm kind)
+ : FunctionOptions(internal::kSelectKOptionsType),
+ k(k),
+ sort_keys(std::move(sort_keys)),
+ kind(kind) {}
+
+bool SelectKOptions::is_top_k() const {
+ SortOrder order = SortOrder::Descending;
+ for (const auto& k : sort_keys) {
+ order = k.order;
+ if (order != SortOrder::Descending) {
+ break;
+ }
+ }
+ return order == SortOrder::Descending;
+}
Review comment:
```suggestion
bool SelectKOptions::is_top_k() const {
for (const auto& k : sort_keys) {
if (k.order != SortOrder::Descending) {
return false;
}
}
return true;
}
```
##########
File path: cpp/src/arrow/compute/kernels/vector_topk_benchmark.cc
##########
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/benchmark_util.h"
+
+namespace arrow {
+namespace compute {
+constexpr auto kSeed = 0x0ff1ce;
+
+Result<std::shared_ptr<Array>> TopKWithSorting(const Array& values, int64_t n)
{
Review comment:
This is unused?
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1799,711 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+const auto kDefaultSelectKOptions = SelectKOptions::Defaults();
+const auto kDefaultTopKOptions = TopKOptions::Defaults();
+const auto kDefaultBottomKOptions = BottomKOptions::Defaults();
+
+const FunctionDoc select_k_doc(
+ "Returns the first k elements ordered by `options.keys`",
+ ("This function computes the k elements of the input\n"
+ "array, record batch or table specified in the column names
(`options.sort_keys`).\n"
+ "The columns that are not specified are returned as well, but not used
for\n"
+ "ordering. Null values are considered greater than any other value and
are\n"
+ "therefore sorted at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+const FunctionDoc top_k_doc(
+ "Returns the first k elements ordered by `options.keys` in ascending
order",
+ ("This function computes the k largest elements in ascending order of the
input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "TopKOptions");
+
+const FunctionDoc bottom_k_doc(
+ "Returns the first k elements ordered by `options.keys` in descending
order",
+ ("This function computes the k smallest elements in descending order of
the input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "BottomKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForNumericBasedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
+
+template <SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval);
+};
+
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return lval < rval;
+ }
+};
+
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return rval < lval;
+ }
+};
+
+template <SortOrder sort_order>
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ options_(options),
+ physical_type_(GetPhysicalType(array.type())),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ ArrayType arr(array_.data());
+ std::vector<uint64_t> indices(arr.length());
+
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+ if (options_.k > arr.length()) {
+ options_.k = arr.length();
+ }
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ SelectKComparator<sort_order> comparator;
+ auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto lval = GetView::LogicalValue(arr.GetView(x_index));
+ const auto rval = GetView::LogicalValue(arr.GetView(heap.top()));
+ if (comparator(lval, rval)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Array& array_;
+ SelectKOptions options_;
+ const std::shared_ptr<DataType> physical_type_;
+ Datum* output_;
+};
+
+template <typename ArrayType>
+struct TypedHeapItem {
+ uint64_t index;
+ uint64_t offset;
+ ArrayType* array;
+};
+
+template <SortOrder sort_order>
+class ChunkedArraySelecter : public TypeVisitor {
+ public:
+ ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ chunked_array_(chunked_array),
+ physical_type_(GetPhysicalType(chunked_array.type())),
+ physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+ options_(options),
+ ctx_(ctx),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ using HeapItem = TypedHeapItem<ArrayType>;
+
+ const auto num_chunks = chunked_array_.num_chunks();
+ if (num_chunks == 0) {
+ return Status::OK();
+ }
+ if (options_.k > chunked_array_.length()) {
+ options_.k = chunked_array_.length();
+ }
+ std::function<bool(const HeapItem&, const HeapItem&)> cmp;
+ SelectKComparator<sort_order> comparator;
+
+ cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool {
+ const auto lval = GetView::LogicalValue(left.array->GetView(left.index));
+ const auto rval =
GetView::LogicalValue(right.array->GetView(right.index));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<HeapItem, decltype(cmp)> heap(cmp);
+ std::vector<std::shared_ptr<ArrayType>> chunks_holder;
+ uint64_t offset = 0;
+ for (const auto& chunk : physical_chunks_) {
+ if (chunk->length() == 0) continue;
+ chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data()));
+ ArrayType& arr = *chunks_holder[chunks_holder.size() - 1];
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin && heap.size() <
static_cast<size_t>(options_.k); ++iter) {
+ heap.Push(HeapItem{*iter, offset, &arr});
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto& xval = GetView::LogicalValue(arr.GetView(x_index));
+ auto top_item = heap.top();
+ const auto& top_value =
+ GetView::LogicalValue(top_item.array->GetView(top_item.index));
+ if (comparator(xval, top_value)) {
+ heap.ReplaceTop(HeapItem{x_index, offset, &arr});
+ }
+ }
+ offset += chunk->length();
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ auto top_item = heap.top();
+ *out_cbegin = top_item.index + top_item.offset;
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ const ChunkedArray& chunked_array_;
+ const std::shared_ptr<DataType> physical_type_;
+ const ArrayVector physical_chunks_;
+ SelectKOptions options_;
+ ExecContext* ctx_;
+ Datum* output_;
+};
+
+class RecordBatchSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ record_batch_(record_batch),
+ options_(options),
+ output_(output),
+ sort_keys_(ResolveSortKeys(record_batch, options.sort_keys, &status_)),
+ comparator_(sort_keys_) {}
+
+ Status Run() {
+ ARROW_RETURN_NOT_OK(status_);
+ return sort_keys_[0].type->Accept(this);
+ }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status*
status) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto array = batch.GetColumnByName(key.name);
+ if (!array) {
+ *status = Status::Invalid("Nonexistent sort key column: ", key.name);
+ break;
+ }
+ resolved.emplace_back(array, key.order);
+ }
+ return resolved;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+ const ArrayType& arr = checked_cast<const
ArrayType&>(first_sort_key.array);
+
+ const auto num_rows = record_batch_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (options_.k > record_batch_.num_rows()) {
+ options_.k = record_batch_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ if (lval == rval) {
+ // If the left value equals to the right value,
+ // we need to compare the second and following
+ // sort keys.
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ auto top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const RecordBatch& record_batch_;
+ SelectKOptions options_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+ Status status_;
+};
+
+class TableSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ table_(table),
+ options_(options),
+ output_(output),
+ sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)),
+ comparator_(sort_keys_) {}
+
+ Status Run() {
+ ARROW_RETURN_NOT_OK(status_);
+ return sort_keys_[0].type->Accept(this);
+ }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const Table& table, const std::vector<SortKey>& sort_keys, Status*
status) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto chunked_array = table.GetColumnByName(key.name);
+ if (!chunked_array) {
+ *status = Status::Invalid("Nonexistent sort key column: ", key.name);
+ break;
+ }
+ resolved.emplace_back(*chunked_array, key.order);
+ }
+ return resolved;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For non-float types.
+ template <typename Type>
+ enable_if_t<!is_floating_type<Type>::value, uint64_t*>
PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ if (first_sort_key.null_count == 0) {
+ return indices_end;
+ }
+ StablePartitioner partitioner;
+ auto nulls_begin =
+ partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t
index) {
+ const auto chunk =
first_sort_key.GetChunk<ArrayType>((int64_t)index);
+ return !chunk.IsNull();
+ });
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ auto& comparator = comparator_;
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nulls_begin;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For float types.
+ template <typename Type>
+ enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ StablePartitioner partitioner;
+ uint64_t* nulls_begin;
+ if (first_sort_key.null_count == 0) {
+ nulls_begin = indices_end;
+ } else {
+ nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t
index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !chunk.IsNull();
+ });
+ }
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ uint64_t* nans_begin = partitioner(indices_begin, nulls_begin,
[&](uint64_t index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !std::isnan(chunk.Value());
+ });
+ auto& comparator = comparator_;
+ // Sort all NaNs by the second and following sort keys.
+ std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ // Sort all nulls by the second and following sort keys.
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nans_begin;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+
+ const auto num_rows = table_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (options_.k > table_.num_rows()) {
+ options_.k = table_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ auto chunk_left = first_sort_key.template GetChunk<ArrayType>(left);
+ auto chunk_right = first_sort_key.template GetChunk<ArrayType>(right);
+ auto value_left = chunk_left.Value();
+ auto value_right = chunk_right.Value();
+ if (value_left == value_right) {
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(value_left, value_right);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ std::vector<uint64_t> indices(num_rows);
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter =
+ this->PartitionNullsInternal<InType>(indices_begin, indices_end,
first_sort_key);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ uint64_t top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Table& table_;
+ SelectKOptions options_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+ Status status_;
+};
+
+class SelectKthMetaFunction {
+ public:
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx)
const {
+ const SelectKOptions& select_k_options = static_cast<const
SelectKOptions&>(*options);
+ if (select_k_options.k < 0) {
+ return Status::Invalid("TopK/BottomK requires a valid `k` parameter");
+ }
+ switch (args[0].kind()) {
+ case Datum::ARRAY: {
+ if (select_k_options.is_top_k()) {
+ return SelectKth<SortOrder::Descending>(*args[0].make_array(),
select_k_options,
+ ctx);
+ } else {
+ return SelectKth<SortOrder::Ascending>(*args[0].make_array(),
select_k_options,
+ ctx);
+ }
+ } break;
+ case Datum::CHUNKED_ARRAY: {
+ if (select_k_options.is_top_k()) {
+ return SelectKth<SortOrder::Descending>(*args[0].chunked_array(),
+ select_k_options, ctx);
+ } else {
+ return SelectKth<SortOrder::Ascending>(*args[0].chunked_array(),
+ select_k_options, ctx);
+ }
+ } break;
+ case Datum::RECORD_BATCH:
+ return SelectKth(*args[0].record_batch(), select_k_options, ctx);
+ break;
+ case Datum::TABLE:
+ return SelectKth(*args[0].table(), select_k_options, ctx);
+ break;
+ default:
+ break;
+ }
+ return Status::NotImplemented(
+ "Unsupported types for sort_indices operation: "
+ "values=",
+ args[0].ToString());
+ }
+
+ private:
+ template <SortOrder sort_order>
+ Result<Datum> SelectKth(const Array& array, const SelectKOptions& options,
+ ExecContext* ctx) const {
+ Datum output;
+ ArraySelecter<sort_order> selecter(ctx, array, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+
+ template <SortOrder sort_order>
+ Result<Datum> SelectKth(const ChunkedArray& chunked_array,
+ const SelectKOptions& options, ExecContext* ctx)
const {
+ Datum output;
+ ChunkedArraySelecter<sort_order> selecter(ctx, chunked_array, options,
&output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+ Result<Datum> SelectKth(const RecordBatch& record_batch, const
SelectKOptions& options,
+ ExecContext* ctx) const {
+ Datum output;
+ RecordBatchSelecter selecter(ctx, record_batch, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+ Result<Datum> SelectKth(const Table& table, const SelectKOptions& options,
+ ExecContext* ctx) const {
+ Datum output;
+ TableSelecter selecter(ctx, table, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+};
+
+class SelectKMetaFunction : public MetaFunction {
+ public:
+ SelectKMetaFunction()
+ : MetaFunction("select_k", Arity::Unary(), &select_k_doc,
&kDefaultSelectKOptions) {
+ }
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ SelectKthMetaFunction impl;
+ return impl.ExecuteImpl(args, options, ctx);
+ }
+};
+
+class TopKMetaFunction : public MetaFunction {
+ public:
+ TopKMetaFunction()
+ : MetaFunction("top_k", Arity::Unary(), &top_k_doc,
&kDefaultTopKOptions) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ SelectKthMetaFunction impl;
+ const TopKOptions& opts = static_cast<const TopKOptions&>(*options);
+
+ std::vector<SortKey> sort_keys;
+ for (const auto& name : opts.keys)
+ sort_keys.emplace_back(SortKey(name, TopKOptions::order()));
+ if (args[0].kind() == Datum::ARRAY || args[0].kind() ==
Datum::CHUNKED_ARRAY) {
+ sort_keys.emplace_back(SortKey("not-used", TopKOptions::order()));
+ }
+ SelectKOptions select_k_options(opts.k, sort_keys, opts.kind);
+ return impl.ExecuteImpl(args, &select_k_options, ctx);
+ }
+};
+
+class BottomKMetaFunction : public MetaFunction {
+ public:
+ BottomKMetaFunction()
+ : MetaFunction("bottom_k", Arity::Unary(), &bottom_k_doc,
&kDefaultBottomKOptions) {
+ }
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ SelectKthMetaFunction impl;
+ const BottomKOptions& opts = static_cast<const BottomKOptions&>(*options);
+
+ std::vector<SortKey> sort_keys;
+ for (const auto& name : opts.keys)
+ sort_keys.emplace_back(SortKey(name, BottomKOptions::order()));
+ if (args[0].kind() == Datum::ARRAY || args[0].kind() ==
Datum::CHUNKED_ARRAY) {
+ sort_keys.emplace_back(SortKey("not-used", BottomKOptions::order()));
+ }
+ SelectKOptions select_k_options(opts.k, sort_keys, opts.kind);
+ return impl.ExecuteImpl(args, &select_k_options, ctx);
+ }
+};
Review comment:
As stated above, I would rather we not have explicit top_k and bottom_k
functions if they're just going to dispatch to select_k. We should have only
select_k; top_k/bottom_k can be convenience constructors on the options
themselves.
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1799,711 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+const auto kDefaultSelectKOptions = SelectKOptions::Defaults();
+const auto kDefaultTopKOptions = TopKOptions::Defaults();
+const auto kDefaultBottomKOptions = BottomKOptions::Defaults();
+
+const FunctionDoc select_k_doc(
+ "Returns the first k elements ordered by `options.keys`",
+ ("This function computes the k elements of the input\n"
+ "array, record batch or table specified in the column names
(`options.sort_keys`).\n"
+ "The columns that are not specified are returned as well, but not used
for\n"
+ "ordering. Null values are considered greater than any other value and
are\n"
+ "therefore sorted at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+const FunctionDoc top_k_doc(
+ "Returns the first k elements ordered by `options.keys` in ascending
order",
+ ("This function computes the k largest elements in ascending order of the
input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "TopKOptions");
+
+const FunctionDoc bottom_k_doc(
+ "Returns the first k elements ordered by `options.keys` in descending
order",
+ ("This function computes the k smallest elements in descending order of
the input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "BottomKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForNumericBasedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
+
+template <SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval);
+};
+
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return lval < rval;
+ }
+};
+
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return rval < lval;
+ }
+};
+
+template <SortOrder sort_order>
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ options_(options),
+ physical_type_(GetPhysicalType(array.type())),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ ArrayType arr(array_.data());
+ std::vector<uint64_t> indices(arr.length());
+
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+ if (options_.k > arr.length()) {
+ options_.k = arr.length();
+ }
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ SelectKComparator<sort_order> comparator;
+ auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto lval = GetView::LogicalValue(arr.GetView(x_index));
+ const auto rval = GetView::LogicalValue(arr.GetView(heap.top()));
+ if (comparator(lval, rval)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Array& array_;
+ SelectKOptions options_;
+ const std::shared_ptr<DataType> physical_type_;
+ Datum* output_;
+};
+
// Heap entry used when selecting across the chunks of a ChunkedArray.
// `index` is the position inside `*array` (one chunk), and `offset` is that
// chunk's starting position in the chunked array, so `index + offset` is the
// position in the whole chunked array. The heap only reads through `array`.
template <typename ArrayType>
struct TypedHeapItem {
  uint64_t index;
  uint64_t offset;
  const ArrayType* array;
};
+
+template <SortOrder sort_order>
+class ChunkedArraySelecter : public TypeVisitor {
+ public:
+ ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ chunked_array_(chunked_array),
+ physical_type_(GetPhysicalType(chunked_array.type())),
+ physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+ options_(options),
+ ctx_(ctx),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ using HeapItem = TypedHeapItem<ArrayType>;
+
+ const auto num_chunks = chunked_array_.num_chunks();
+ if (num_chunks == 0) {
+ return Status::OK();
+ }
+ if (options_.k > chunked_array_.length()) {
+ options_.k = chunked_array_.length();
+ }
+ std::function<bool(const HeapItem&, const HeapItem&)> cmp;
+ SelectKComparator<sort_order> comparator;
+
+ cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool {
+ const auto lval = GetView::LogicalValue(left.array->GetView(left.index));
+ const auto rval =
GetView::LogicalValue(right.array->GetView(right.index));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<HeapItem, decltype(cmp)> heap(cmp);
+ std::vector<std::shared_ptr<ArrayType>> chunks_holder;
+ uint64_t offset = 0;
+ for (const auto& chunk : physical_chunks_) {
+ if (chunk->length() == 0) continue;
+ chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data()));
+ ArrayType& arr = *chunks_holder[chunks_holder.size() - 1];
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin && heap.size() <
static_cast<size_t>(options_.k); ++iter) {
+ heap.Push(HeapItem{*iter, offset, &arr});
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto& xval = GetView::LogicalValue(arr.GetView(x_index));
+ auto top_item = heap.top();
+ const auto& top_value =
+ GetView::LogicalValue(top_item.array->GetView(top_item.index));
+ if (comparator(xval, top_value)) {
+ heap.ReplaceTop(HeapItem{x_index, offset, &arr});
+ }
+ }
+ offset += chunk->length();
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ auto top_item = heap.top();
+ *out_cbegin = top_item.index + top_item.offset;
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ const ChunkedArray& chunked_array_;
+ const std::shared_ptr<DataType> physical_type_;
+ const ArrayVector physical_chunks_;
+ SelectKOptions options_;
+ ExecContext* ctx_;
+ Datum* output_;
+};
+
+class RecordBatchSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ record_batch_(record_batch),
+ options_(options),
+ output_(output),
+ sort_keys_(ResolveSortKeys(record_batch, options.sort_keys, &status_)),
+ comparator_(sort_keys_) {}
+
+ Status Run() {
+ ARROW_RETURN_NOT_OK(status_);
+ return sort_keys_[0].type->Accept(this);
+ }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status*
status) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto array = batch.GetColumnByName(key.name);
+ if (!array) {
+ *status = Status::Invalid("Nonexistent sort key column: ", key.name);
+ break;
+ }
+ resolved.emplace_back(array, key.order);
+ }
+ return resolved;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+ const ArrayType& arr = checked_cast<const
ArrayType&>(first_sort_key.array);
+
+ const auto num_rows = record_batch_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (options_.k > record_batch_.num_rows()) {
+ options_.k = record_batch_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ if (lval == rval) {
+ // If the left value equals to the right value,
+ // we need to compare the second and following
+ // sort keys.
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ auto top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const RecordBatch& record_batch_;
+ SelectKOptions options_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+ Status status_;
+};
+
+class TableSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ table_(table),
+ options_(options),
+ output_(output),
+ sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)),
+ comparator_(sort_keys_) {}
+
+ Status Run() {
+ ARROW_RETURN_NOT_OK(status_);
+ return sort_keys_[0].type->Accept(this);
+ }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const Table& table, const std::vector<SortKey>& sort_keys, Status*
status) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto chunked_array = table.GetColumnByName(key.name);
+ if (!chunked_array) {
+ *status = Status::Invalid("Nonexistent sort key column: ", key.name);
+ break;
+ }
+ resolved.emplace_back(*chunked_array, key.order);
+ }
+ return resolved;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For non-float types.
+ template <typename Type>
+ enable_if_t<!is_floating_type<Type>::value, uint64_t*>
PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ if (first_sort_key.null_count == 0) {
+ return indices_end;
+ }
+ StablePartitioner partitioner;
+ auto nulls_begin =
+ partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t
index) {
+ const auto chunk =
first_sort_key.GetChunk<ArrayType>((int64_t)index);
+ return !chunk.IsNull();
+ });
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ auto& comparator = comparator_;
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nulls_begin;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For float types.
+ template <typename Type>
+ enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ StablePartitioner partitioner;
+ uint64_t* nulls_begin;
+ if (first_sort_key.null_count == 0) {
+ nulls_begin = indices_end;
+ } else {
+ nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t
index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !chunk.IsNull();
+ });
+ }
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ uint64_t* nans_begin = partitioner(indices_begin, nulls_begin,
[&](uint64_t index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !std::isnan(chunk.Value());
+ });
+ auto& comparator = comparator_;
+ // Sort all NaNs by the second and following sort keys.
+ std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ // Sort all nulls by the second and following sort keys.
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nans_begin;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+
+ const auto num_rows = table_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (options_.k > table_.num_rows()) {
+ options_.k = table_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ auto chunk_left = first_sort_key.template GetChunk<ArrayType>(left);
+ auto chunk_right = first_sort_key.template GetChunk<ArrayType>(right);
+ auto value_left = chunk_left.Value();
+ auto value_right = chunk_right.Value();
+ if (value_left == value_right) {
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(value_left, value_right);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ std::vector<uint64_t> indices(num_rows);
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter =
+ this->PartitionNullsInternal<InType>(indices_begin, indices_end,
first_sort_key);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ uint64_t top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Table& table_;
+ SelectKOptions options_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+ Status status_;
+};
+
+class SelectKthMetaFunction {
+ public:
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx)
const {
+ const SelectKOptions& select_k_options = static_cast<const
SelectKOptions&>(*options);
+ if (select_k_options.k < 0) {
+ return Status::Invalid("TopK/BottomK requires a valid `k` parameter");
+ }
+ switch (args[0].kind()) {
+ case Datum::ARRAY: {
+ if (select_k_options.is_top_k()) {
+ return SelectKth<SortOrder::Descending>(*args[0].make_array(),
select_k_options,
+ ctx);
+ } else {
+ return SelectKth<SortOrder::Ascending>(*args[0].make_array(),
select_k_options,
+ ctx);
+ }
+ } break;
+ case Datum::CHUNKED_ARRAY: {
+ if (select_k_options.is_top_k()) {
+ return SelectKth<SortOrder::Descending>(*args[0].chunked_array(),
+ select_k_options, ctx);
+ } else {
+ return SelectKth<SortOrder::Ascending>(*args[0].chunked_array(),
+ select_k_options, ctx);
+ }
+ } break;
+ case Datum::RECORD_BATCH:
+ return SelectKth(*args[0].record_batch(), select_k_options, ctx);
+ break;
+ case Datum::TABLE:
+ return SelectKth(*args[0].table(), select_k_options, ctx);
+ break;
+ default:
+ break;
+ }
+ return Status::NotImplemented(
+ "Unsupported types for sort_indices operation: "
+ "values=",
+ args[0].ToString());
+ }
+
+ private:
+ template <SortOrder sort_order>
+ Result<Datum> SelectKth(const Array& array, const SelectKOptions& options,
+ ExecContext* ctx) const {
+ Datum output;
+ ArraySelecter<sort_order> selecter(ctx, array, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+
+ template <SortOrder sort_order>
+ Result<Datum> SelectKth(const ChunkedArray& chunked_array,
+ const SelectKOptions& options, ExecContext* ctx)
const {
+ Datum output;
+ ChunkedArraySelecter<sort_order> selecter(ctx, chunked_array, options,
&output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+ Result<Datum> SelectKth(const RecordBatch& record_batch, const
SelectKOptions& options,
+ ExecContext* ctx) const {
+ Datum output;
+ RecordBatchSelecter selecter(ctx, record_batch, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+ Result<Datum> SelectKth(const Table& table, const SelectKOptions& options,
+ ExecContext* ctx) const {
+ Datum output;
+ TableSelecter selecter(ctx, table, options, &output);
+ ARROW_RETURN_NOT_OK(selecter.Run());
+ return output;
+ }
+};
+
+class SelectKMetaFunction : public MetaFunction {
+ public:
+ SelectKMetaFunction()
+ : MetaFunction("select_k", Arity::Unary(), &select_k_doc,
&kDefaultSelectKOptions) {
+ }
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ SelectKthMetaFunction impl;
Review comment:
Can we just inline the above class here? It doesn't seem useful to split
them (once we remove TopKMetaFunction et al.).
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1799,711 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+const auto kDefaultSelectKOptions = SelectKOptions::Defaults();
+const auto kDefaultTopKOptions = TopKOptions::Defaults();
+const auto kDefaultBottomKOptions = BottomKOptions::Defaults();
+
+const FunctionDoc select_k_doc(
+ "Returns the first k elements ordered by `options.keys`",
+ ("This function computes the k elements of the input\n"
+ "array, record batch or table specified in the column names
(`options.sort_keys`).\n"
+ "The columns that are not specified are returned as well, but not used
for\n"
+ "ordering. Null values are considered greater than any other value and
are\n"
+ "therefore sorted at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+const FunctionDoc top_k_doc(
+ "Returns the first k elements ordered by `options.keys` in ascending
order",
+ ("This function computes the k largest elements in ascending order of the
input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "TopKOptions");
+
+const FunctionDoc bottom_k_doc(
+ "Returns the first k elements ordered by `options.keys` in descending
order",
+ ("This function computes the k smallest elements in descending order of
the input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "BottomKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForNumericBasedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
+
+template <SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval);
+};
+
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return lval < rval;
+ }
+};
+
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return rval < lval;
+ }
+};
+
+template <SortOrder sort_order>
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ options_(options),
+ physical_type_(GetPhysicalType(array.type())),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ ArrayType arr(array_.data());
+ std::vector<uint64_t> indices(arr.length());
+
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+ if (options_.k > arr.length()) {
+ options_.k = arr.length();
+ }
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ SelectKComparator<sort_order> comparator;
+ auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto lval = GetView::LogicalValue(arr.GetView(x_index));
+ const auto rval = GetView::LogicalValue(arr.GetView(heap.top()));
+ if (comparator(lval, rval)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Array& array_;
+ SelectKOptions options_;
+ const std::shared_ptr<DataType> physical_type_;
+ Datum* output_;
+};
+
// Heap entry used when selecting across the chunks of a ChunkedArray.
// `index` is the position inside `*array` (one chunk), and `offset` is that
// chunk's starting position in the chunked array, so `index + offset` is the
// position in the whole chunked array. The heap only reads through `array`.
template <typename ArrayType>
struct TypedHeapItem {
  uint64_t index;
  uint64_t offset;
  const ArrayType* array;
};
+
+template <SortOrder sort_order>
+class ChunkedArraySelecter : public TypeVisitor {
+ public:
+ ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ chunked_array_(chunked_array),
+ physical_type_(GetPhysicalType(chunked_array.type())),
+ physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+ options_(options),
+ ctx_(ctx),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ using HeapItem = TypedHeapItem<ArrayType>;
+
+ const auto num_chunks = chunked_array_.num_chunks();
+ if (num_chunks == 0) {
+ return Status::OK();
+ }
+ if (options_.k > chunked_array_.length()) {
+ options_.k = chunked_array_.length();
+ }
+ std::function<bool(const HeapItem&, const HeapItem&)> cmp;
+ SelectKComparator<sort_order> comparator;
+
+ cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool {
+ const auto lval = GetView::LogicalValue(left.array->GetView(left.index));
+ const auto rval =
GetView::LogicalValue(right.array->GetView(right.index));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<HeapItem, decltype(cmp)> heap(cmp);
+ std::vector<std::shared_ptr<ArrayType>> chunks_holder;
+ uint64_t offset = 0;
+ for (const auto& chunk : physical_chunks_) {
+ if (chunk->length() == 0) continue;
+ chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data()));
+ ArrayType& arr = *chunks_holder[chunks_holder.size() - 1];
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin && heap.size() <
static_cast<size_t>(options_.k); ++iter) {
+ heap.Push(HeapItem{*iter, offset, &arr});
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto& xval = GetView::LogicalValue(arr.GetView(x_index));
+ auto top_item = heap.top();
+ const auto& top_value =
+ GetView::LogicalValue(top_item.array->GetView(top_item.index));
+ if (comparator(xval, top_value)) {
+ heap.ReplaceTop(HeapItem{x_index, offset, &arr});
+ }
+ }
+ offset += chunk->length();
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ auto top_item = heap.top();
+ *out_cbegin = top_item.index + top_item.offset;
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ const ChunkedArray& chunked_array_;
+ const std::shared_ptr<DataType> physical_type_;
+ const ArrayVector physical_chunks_;
+ SelectKOptions options_;
+ ExecContext* ctx_;
+ Datum* output_;
+};
+
+class RecordBatchSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ record_batch_(record_batch),
+ options_(options),
+ output_(output),
+ sort_keys_(ResolveSortKeys(record_batch, options.sort_keys, &status_)),
+ comparator_(sort_keys_) {}
+
+ Status Run() {
+ ARROW_RETURN_NOT_OK(status_);
+ return sort_keys_[0].type->Accept(this);
+ }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status*
status) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto array = batch.GetColumnByName(key.name);
+ if (!array) {
+ *status = Status::Invalid("Nonexistent sort key column: ", key.name);
+ break;
+ }
+ resolved.emplace_back(array, key.order);
+ }
+ return resolved;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+ const ArrayType& arr = checked_cast<const
ArrayType&>(first_sort_key.array);
+
+ const auto num_rows = record_batch_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (options_.k > record_batch_.num_rows()) {
+ options_.k = record_batch_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ if (lval == rval) {
+ // If the left value equals to the right value,
+ // we need to compare the second and following
+ // sort keys.
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ auto top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const RecordBatch& record_batch_;
+ SelectKOptions options_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+ Status status_;
+};
+
+class TableSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ table_(table),
+ options_(options),
+ output_(output),
+ sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)),
+ comparator_(sort_keys_) {}
+
+ Status Run() {
+ ARROW_RETURN_NOT_OK(status_);
+ return sort_keys_[0].type->Accept(this);
+ }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const Table& table, const std::vector<SortKey>& sort_keys, Status*
status) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto chunked_array = table.GetColumnByName(key.name);
+ if (!chunked_array) {
+ *status = Status::Invalid("Nonexistent sort key column: ", key.name);
+ break;
+ }
+ resolved.emplace_back(*chunked_array, key.order);
+ }
+ return resolved;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For non-float types.
+ template <typename Type>
+ enable_if_t<!is_floating_type<Type>::value, uint64_t*>
PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ if (first_sort_key.null_count == 0) {
+ return indices_end;
+ }
+ StablePartitioner partitioner;
+ auto nulls_begin =
+ partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t
index) {
+ const auto chunk =
first_sort_key.GetChunk<ArrayType>((int64_t)index);
+ return !chunk.IsNull();
+ });
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ auto& comparator = comparator_;
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nulls_begin;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For float types.
+ template <typename Type>
+ enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ StablePartitioner partitioner;
+ uint64_t* nulls_begin;
+ if (first_sort_key.null_count == 0) {
+ nulls_begin = indices_end;
+ } else {
+ nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t
index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !chunk.IsNull();
+ });
+ }
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ uint64_t* nans_begin = partitioner(indices_begin, nulls_begin,
[&](uint64_t index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !std::isnan(chunk.Value());
+ });
+ auto& comparator = comparator_;
+ // Sort all NaNs by the second and following sort keys.
+ std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ // Sort all nulls by the second and following sort keys.
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nans_begin;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+
+ const auto num_rows = table_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (options_.k > table_.num_rows()) {
+ options_.k = table_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ auto chunk_left = first_sort_key.template GetChunk<ArrayType>(left);
+ auto chunk_right = first_sort_key.template GetChunk<ArrayType>(right);
+ auto value_left = chunk_left.Value();
+ auto value_right = chunk_right.Value();
+ if (value_left == value_right) {
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(value_left, value_right);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ std::vector<uint64_t> indices(num_rows);
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter =
+ this->PartitionNullsInternal<InType>(indices_begin, indices_end,
first_sort_key);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ uint64_t top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Table& table_;
+ SelectKOptions options_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+ Status status_;
+};
+
+class SelectKthMetaFunction {
+ public:
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx)
const {
+ const SelectKOptions& select_k_options = static_cast<const
SelectKOptions&>(*options);
+ if (select_k_options.k < 0) {
+ return Status::Invalid("TopK/BottomK requires a valid `k` parameter");
Review comment:
```suggestion
return Status::Invalid("SelectK requires a nonnegative `k`, got ",
select_k_options.k);
```
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1799,711 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+const auto kDefaultSelectKOptions = SelectKOptions::Defaults();
+const auto kDefaultTopKOptions = TopKOptions::Defaults();
+const auto kDefaultBottomKOptions = BottomKOptions::Defaults();
+
+const FunctionDoc select_k_doc(
+ "Returns the first k elements ordered by `options.keys`",
+ ("This function computes the k elements of the input\n"
+ "array, record batch or table specified in the column names
(`options.sort_keys`).\n"
+ "The columns that are not specified are returned as well, but not used
for\n"
+ "ordering. Null values are considered greater than any other value and
are\n"
+ "therefore sorted at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+const FunctionDoc top_k_doc(
+ "Returns the first k elements ordered by `options.keys` in ascending
order",
+ ("This function computes the k largest elements in ascending order of the
input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "TopKOptions");
+
+const FunctionDoc bottom_k_doc(
+ "Returns the first k elements ordered by `options.keys` in descending
order",
+ ("This function computes the k smallest elements in descending order of
the input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "BottomKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForNumericBasedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
+
+template <SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval);
+};
+
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return lval < rval;
+ }
+};
+
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return rval < lval;
+ }
+};
+
+template <SortOrder sort_order>
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ options_(options),
+ physical_type_(GetPhysicalType(array.type())),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ ArrayType arr(array_.data());
+ std::vector<uint64_t> indices(arr.length());
+
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+ if (options_.k > arr.length()) {
+ options_.k = arr.length();
+ }
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ SelectKComparator<sort_order> comparator;
+ auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto lval = GetView::LogicalValue(arr.GetView(x_index));
+ const auto rval = GetView::LogicalValue(arr.GetView(heap.top()));
+ if (comparator(lval, rval)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Array& array_;
+ SelectKOptions options_;
+ const std::shared_ptr<DataType> physical_type_;
+ Datum* output_;
+};
+
+template <typename ArrayType>
+struct TypedHeapItem {
+ uint64_t index;
+ uint64_t offset;
+ ArrayType* array;
+};
+
+template <SortOrder sort_order>
+class ChunkedArraySelecter : public TypeVisitor {
+ public:
+ ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ chunked_array_(chunked_array),
+ physical_type_(GetPhysicalType(chunked_array.type())),
+ physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+ options_(options),
+ ctx_(ctx),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ using HeapItem = TypedHeapItem<ArrayType>;
+
+ const auto num_chunks = chunked_array_.num_chunks();
+ if (num_chunks == 0) {
+ return Status::OK();
+ }
+ if (options_.k > chunked_array_.length()) {
+ options_.k = chunked_array_.length();
+ }
+ std::function<bool(const HeapItem&, const HeapItem&)> cmp;
+ SelectKComparator<sort_order> comparator;
+
+ cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool {
+ const auto lval = GetView::LogicalValue(left.array->GetView(left.index));
+ const auto rval =
GetView::LogicalValue(right.array->GetView(right.index));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<HeapItem, decltype(cmp)> heap(cmp);
+ std::vector<std::shared_ptr<ArrayType>> chunks_holder;
+ uint64_t offset = 0;
+ for (const auto& chunk : physical_chunks_) {
+ if (chunk->length() == 0) continue;
+ chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data()));
+ ArrayType& arr = *chunks_holder[chunks_holder.size() - 1];
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin && heap.size() <
static_cast<size_t>(options_.k); ++iter) {
+ heap.Push(HeapItem{*iter, offset, &arr});
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto& xval = GetView::LogicalValue(arr.GetView(x_index));
+ auto top_item = heap.top();
+ const auto& top_value =
+ GetView::LogicalValue(top_item.array->GetView(top_item.index));
+ if (comparator(xval, top_value)) {
+ heap.ReplaceTop(HeapItem{x_index, offset, &arr});
+ }
+ }
+ offset += chunk->length();
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ auto top_item = heap.top();
+ *out_cbegin = top_item.index + top_item.offset;
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ const ChunkedArray& chunked_array_;
+ const std::shared_ptr<DataType> physical_type_;
+ const ArrayVector physical_chunks_;
+ SelectKOptions options_;
+ ExecContext* ctx_;
+ Datum* output_;
+};
+
+class RecordBatchSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ record_batch_(record_batch),
+ options_(options),
+ output_(output),
+ sort_keys_(ResolveSortKeys(record_batch, options.sort_keys, &status_)),
+ comparator_(sort_keys_) {}
+
+ Status Run() {
+ ARROW_RETURN_NOT_OK(status_);
+ return sort_keys_[0].type->Accept(this);
+ }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status*
status) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto array = batch.GetColumnByName(key.name);
+ if (!array) {
+ *status = Status::Invalid("Nonexistent sort key column: ", key.name);
+ break;
+ }
+ resolved.emplace_back(array, key.order);
+ }
+ return resolved;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+ const ArrayType& arr = checked_cast<const
ArrayType&>(first_sort_key.array);
+
+ const auto num_rows = record_batch_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (options_.k > record_batch_.num_rows()) {
+ options_.k = record_batch_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ if (lval == rval) {
+ // If the left value equals to the right value,
+ // we need to compare the second and following
+ // sort keys.
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ auto top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const RecordBatch& record_batch_;
+ SelectKOptions options_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+ Status status_;
+};
+
+class TableSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ table_(table),
+ options_(options),
+ output_(output),
+ sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)),
+ comparator_(sort_keys_) {}
+
+ Status Run() {
+ ARROW_RETURN_NOT_OK(status_);
+ return sort_keys_[0].type->Accept(this);
+ }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const Table& table, const std::vector<SortKey>& sort_keys, Status*
status) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto chunked_array = table.GetColumnByName(key.name);
+ if (!chunked_array) {
+ *status = Status::Invalid("Nonexistent sort key column: ", key.name);
+ break;
+ }
+ resolved.emplace_back(*chunked_array, key.order);
+ }
+ return resolved;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For non-float types.
+ template <typename Type>
+ enable_if_t<!is_floating_type<Type>::value, uint64_t*>
PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ if (first_sort_key.null_count == 0) {
+ return indices_end;
+ }
+ StablePartitioner partitioner;
+ auto nulls_begin =
+ partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t
index) {
+ const auto chunk =
first_sort_key.GetChunk<ArrayType>((int64_t)index);
+ return !chunk.IsNull();
+ });
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ auto& comparator = comparator_;
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nulls_begin;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For float types.
+ template <typename Type>
+ enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ StablePartitioner partitioner;
+ uint64_t* nulls_begin;
+ if (first_sort_key.null_count == 0) {
+ nulls_begin = indices_end;
+ } else {
+ nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t
index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !chunk.IsNull();
+ });
+ }
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ uint64_t* nans_begin = partitioner(indices_begin, nulls_begin,
[&](uint64_t index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !std::isnan(chunk.Value());
+ });
+ auto& comparator = comparator_;
+ // Sort all NaNs by the second and following sort keys.
+ std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ // Sort all nulls by the second and following sort keys.
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nans_begin;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+
+ const auto num_rows = table_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (options_.k > table_.num_rows()) {
+ options_.k = table_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ auto chunk_left = first_sort_key.template GetChunk<ArrayType>(left);
+ auto chunk_right = first_sort_key.template GetChunk<ArrayType>(right);
+ auto value_left = chunk_left.Value();
+ auto value_right = chunk_right.Value();
+ if (value_left == value_right) {
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(value_left, value_right);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+ std::vector<uint64_t> indices(num_rows);
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter =
+ this->PartitionNullsInternal<InType>(indices_begin, indices_end,
first_sort_key);
+ auto kth_begin = indices_begin + options_.k;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ uint64_t top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Table& table_;
+ SelectKOptions options_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+ Status status_;
+};
+
+class SelectKthMetaFunction {
+ public:
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx)
const {
+ const SelectKOptions& select_k_options = static_cast<const
SelectKOptions&>(*options);
+ if (select_k_options.k < 0) {
+ return Status::Invalid("TopK/BottomK requires a valid `k` parameter");
+ }
+ switch (args[0].kind()) {
+ case Datum::ARRAY: {
+ if (select_k_options.is_top_k()) {
+ return SelectKth<SortOrder::Descending>(*args[0].make_array(),
select_k_options,
+ ctx);
+ } else {
+ return SelectKth<SortOrder::Ascending>(*args[0].make_array(),
select_k_options,
+ ctx);
+ }
+ } break;
+ case Datum::CHUNKED_ARRAY: {
+ if (select_k_options.is_top_k()) {
+ return SelectKth<SortOrder::Descending>(*args[0].chunked_array(),
+ select_k_options, ctx);
+ } else {
+ return SelectKth<SortOrder::Ascending>(*args[0].chunked_array(),
+ select_k_options, ctx);
+ }
+ } break;
+ case Datum::RECORD_BATCH:
+ return SelectKth(*args[0].record_batch(), select_k_options, ctx);
+ break;
+ case Datum::TABLE:
+ return SelectKth(*args[0].table(), select_k_options, ctx);
+ break;
+ default:
+ break;
+ }
+ return Status::NotImplemented(
+ "Unsupported types for sort_indices operation: "
Review comment:
```suggestion
"Unsupported types for select_k operation: "
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]