lidavidm commented on a change in pull request #11019:
URL: https://github.com/apache/arrow/pull/11019#discussion_r705496801
##########
File path: docs/source/python/api/compute.rst
##########
@@ -350,6 +350,7 @@ Sorts and partitions
partition_nth_indices
sort_indices
+ select_k
Review comment:
Please try to keep this list sorted alphabetically.
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1798,627 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+const auto kDefaultSelectKOptions = SelectKOptions::Defaults();
+
+const FunctionDoc select_k_doc(
+ "Returns the first k elements ordered by `options.keys`",
+ ("This function computes the k elements of the input\n"
+ "array, record batch or table specified in the column names
(`options.sort_keys`).\n"
+ "The columns that are not specified are returned as well, but not used
for\n"
+ "ordering. Null values are considered greater than any other value and
are\n"
+ "therefore sorted at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForNumericBasedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
Review comment:
A few things here. Why accept out_type when it casts it to UInt64Type?
Maybe just change this to `MakeMutableUInt64Array` or similar and use
`sizeof(uint64_t)`.
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,737 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (order == SortOrder::Ascending) {
+ return lval <= rval;
+ } else {
+ return rval <= lval;
+ }
+ }
+};
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Datum& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return SelectKUnstable(values, SelectKOptions::TopKDefault(k));
+ } else {
+ return SelectKUnstable(values, SelectKOptions::BottomKDefault(k));
+ }
+}
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Datum& values,
+ const SelectKOptions& options) {
+ if (order == SortOrder::Descending) {
+ return SelectKUnstable(Datum(values), options);
+ } else {
+ return SelectKUnstable(Datum(values), options);
+ }
+}
+
+void ValidateSelectK(const Datum& datum, int64_t k, Array& select_k_indices,
+ SortOrder order, bool stable_sort = false) {
+ ASSERT_TRUE(datum.is_arraylike());
+ ASSERT_OK_AND_ASSIGN(auto sorted_indices,
+ SortIndices(datum, SortOptions({SortKey("unused",
order)})));
+
+ if (k < datum.length()) {
Review comment:
Doesn't this mean we aren't asserting anything if k > datum.length()?
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1798,627 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+const auto kDefaultSelectKOptions = SelectKOptions::Defaults();
+
+const FunctionDoc select_k_doc(
+ "Returns the first k elements ordered by `options.keys`",
+ ("This function computes the k elements of the input\n"
+ "array, record batch or table specified in the column names
(`options.sort_keys`).\n"
+ "The columns that are not specified are returned as well, but not used
for\n"
+ "ordering. Null values are considered greater than any other value and
are\n"
+ "therefore sorted at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForNumericBasedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
+
+template <SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval);
+};
+
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return lval < rval;
+ }
+};
+
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return rval < lval;
+ }
+};
+
+template <SortOrder sort_order>
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ k_(options.k),
+ physical_type_(GetPhysicalType(array.type())),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ ArrayType arr(array_.data());
+ std::vector<uint64_t> indices(arr.length());
+
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+ if (k_ > arr.length()) {
+ k_ = arr.length();
+ }
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + k_;
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ SelectKComparator<sort_order> comparator;
+ auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return comparator(lval, rval);
+ };
+ using HeapContainer =
+ std::priority_queue<uint64_t, std::vector<uint64_t>, decltype(cmp)>;
+ HeapContainer heap(indices_begin, kth_begin, cmp);
+ for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ if (cmp(x_index, heap.top())) {
+ heap.pop();
+ heap.push(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Array& array_;
+ int64_t k_;
+ const std::shared_ptr<DataType> physical_type_;
+ Datum* output_;
+};
+
+template <typename ArrayType>
+struct TypedHeapItem {
+ uint64_t index;
+ uint64_t offset;
+ ArrayType* array;
+};
+
+template <SortOrder sort_order>
+class ChunkedArraySelecter : public TypeVisitor {
+ public:
+ ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ chunked_array_(chunked_array),
+ physical_type_(GetPhysicalType(chunked_array.type())),
+ physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+ k_(options.k),
+ ctx_(ctx),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ using HeapItem = TypedHeapItem<ArrayType>;
+
+ const auto num_chunks = chunked_array_.num_chunks();
+ if (num_chunks == 0) {
+ return Status::OK();
+ }
+ if (k_ > chunked_array_.length()) {
+ k_ = chunked_array_.length();
+ }
+ std::function<bool(const HeapItem&, const HeapItem&)> cmp;
+ SelectKComparator<sort_order> comparator;
+
+ cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool {
+ const auto lval = GetView::LogicalValue(left.array->GetView(left.index));
+ const auto rval =
GetView::LogicalValue(right.array->GetView(right.index));
+ return comparator(lval, rval);
+ };
+ using HeapContainer =
+ std::priority_queue<HeapItem, std::vector<HeapItem>, decltype(cmp)>;
+
+ HeapContainer heap(cmp);
+ std::vector<std::shared_ptr<ArrayType>> chunks_holder;
+ uint64_t offset = 0;
+ for (const auto& chunk : physical_chunks_) {
+ if (chunk->length() == 0) continue;
+ chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data()));
+ ArrayType& arr = *chunks_holder[chunks_holder.size() - 1];
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0);
+ auto kth_begin = indices_begin + k_;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin && heap.size() < static_cast<size_t>(k_);
++iter) {
+ heap.push(HeapItem{*iter, offset, &arr});
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto& xval = GetView::LogicalValue(arr.GetView(x_index));
+ auto top_item = heap.top();
+ const auto& top_value =
+ GetView::LogicalValue(top_item.array->GetView(top_item.index));
+ if (comparator(xval, top_value)) {
+ heap.pop();
+ heap.push(HeapItem{x_index, offset, &arr});
+ }
+ }
+ offset += chunk->length();
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ auto top_item = heap.top();
+ *out_cbegin = top_item.index + top_item.offset;
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ const ChunkedArray& chunked_array_;
+ const std::shared_ptr<DataType> physical_type_;
+ const ArrayVector physical_chunks_;
+ int64_t k_;
+ ExecContext* ctx_;
+ Datum* output_;
+};
+
+class RecordBatchSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ record_batch_(record_batch),
+ k_(options.k),
+ output_(output),
+ sort_keys_(ResolveSortKeys(record_batch, options.sort_keys)),
+ comparator_(sort_keys_) {}
+
+ Status Run() { return sort_keys_[0].type->Accept(this); }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const RecordBatch& batch, const std::vector<SortKey>& sort_keys) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto array = batch.GetColumnByName(key.name);
+ resolved.emplace_back(array, key.order);
+ }
+ return resolved;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+ const ArrayType& arr = checked_cast<const
ArrayType&>(first_sort_key.array);
+
+ const auto num_rows = record_batch_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (k_ > record_batch_.num_rows()) {
+ k_ = record_batch_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ if (lval == rval) {
+ // If the left value equals to the right value,
+ // we need to compare the second and following
+ // sort keys.
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(lval, rval);
+ };
+ using HeapContainer =
+ std::priority_queue<uint64_t, std::vector<uint64_t>, decltype(cmp)>;
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + k_;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ HeapContainer heap(indices_begin, kth_begin, cmp);
+ for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ auto top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.pop();
+ heap.push(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const RecordBatch& record_batch_;
+ int64_t k_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+};
+
+class TableSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ table_(table),
+ k_(options.k),
+ output_(output),
+ sort_keys_(ResolveSortKeys(table, options.sort_keys)),
+ comparator_(sort_keys_) {}
+
+ Status Run() { return sort_keys_[0].type->Accept(this); }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const Table& table, const std::vector<SortKey>& sort_keys) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto chunked_array = table.GetColumnByName(key.name);
+ resolved.emplace_back(*chunked_array, key.order);
+ }
+ return resolved;
+ }
+
+ // Behaves like PartitionNulls() but this supports multiple sort keys.
+ //
+ // For non-float types.
+ template <typename Type>
+ enable_if_t<!is_floating_type<Type>::value, uint64_t*>
PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ if (first_sort_key.null_count == 0) {
+ return indices_end;
+ }
+ StablePartitioner partitioner;
+ auto nulls_begin =
+ partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t
index) {
+ const auto chunk =
first_sort_key.GetChunk<ArrayType>((int64_t)index);
Review comment:
Please use static_cast, not a C-style cast.
##########
File path: docs/source/cpp/compute.rst
##########
@@ -1367,6 +1367,10 @@ value, but smaller than nulls.
+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+
| sort_indices | Unary | Boolean, Numeric, Temporal | UInt64
| :struct:`SortOptions` | \(2) \(5) |
+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+
+| select_k | Unary | Binary- and String-like | UInt64
| :struct:`SelectKOptions` | \(5) \(3) \(6) |
++-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+
+| select_k | Unary | Boolean, Numeric, Temporal | UInt64
| :struct:`SelectKOptions` | \(5) \(6) |
++-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+
Review comment:
Why is this row duplicated? We can keep all the notes together (it
should be clear that (3) applies to binary/string types only…)
##########
File path: docs/source/cpp/compute.rst
##########
@@ -1367,6 +1367,10 @@ value, but smaller than nulls.
+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+
| sort_indices | Unary | Boolean, Numeric, Temporal | UInt64
| :struct:`SortOptions` | \(2) \(5) |
+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+
+| select_k | Unary | Binary- and String-like | UInt64
| :struct:`SelectKOptions` | \(5) \(3) \(6) |
++-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+
+| select_k | Unary | Boolean, Numeric, Temporal | UInt64
| :struct:`SelectKOptions` | \(5) \(6) |
++-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+
Review comment:
Also, please try to keep the table sorted alphabetically.
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1798,627 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+const auto kDefaultSelectKOptions = SelectKOptions::Defaults();
+
+const FunctionDoc select_k_doc(
+ "Returns the first k elements ordered by `options.keys`",
+ ("This function computes the k elements of the input\n"
+ "array, record batch or table specified in the column names
(`options.sort_keys`).\n"
+ "The columns that are not specified are returned as well, but not used
for\n"
+ "ordering. Null values are considered greater than any other value and
are\n"
+ "therefore sorted at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForNumericBasedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
Review comment:
```suggestion
Result<std::shared_ptr<ArrayData>> MakeUInt64Array(MemoryPool* memory_pool,
int64_t length) {
auto buffer_size = length * sizeof(uint64_t);
ARROW_ASSIGN_OR_RAISE(auto data, AllocateBuffer(buffer_size, memory_pool));
return ArrayData::Make(uint64(), length, {nullptr, std::move(data)},
/*null_count=*/0);
}
```
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1784,711 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+using SelectKOptionsState = internal::OptionsWrapper<SelectKOptions>;
+const auto kDefaultTopKOptions = SelectKOptions::TopKDefault();
+const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault();
+
+const FunctionDoc top_k_doc(
+ "Return the indices that would partition an array array, record batch or
table\n"
+ "around a pivot",
+ ("@TODO"), {"input", "k"}, "PartitionNthOptions");
+
+const FunctionDoc bottom_k_doc(
+ "Return the indices that would partition an array array, record batch or
table\n"
+ "around a pivot",
+ ("@TODO"), {"input", "k"}, "PartitionNthOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForFixedSizedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
+
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, int64_t k, const
SortOrder order,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ k_(k),
+ physical_type_(GetPhysicalType(array.type())),
+ order_(order),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
Review comment:
Looking at this again, I suppose we don't implement sorting for things
like timestamp types in general(?) - we may want to file a follow-up to expand
type support while trying to reduce the amount of generated code.
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1799,736 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+using SelectKOptionsState = internal::OptionsWrapper<SelectKOptions>;
+const auto kDefaultTopKOptions = SelectKOptions::TopKDefault();
+const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault();
+
+const FunctionDoc top_k_doc(
+ "Returns the first k elements ordered by `options.keys` in ascending
order",
+ ("This function computes the k largest elements in ascending order of the
input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+const FunctionDoc bottom_k_doc(
+ "Returns the first k elements ordered by `options.keys` in descending
order",
+ ("This function computes the k smallest elements in descending order of
the input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForFixedSizedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
+
+template <SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval);
+};
+
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return lval < rval;
+ }
+};
+
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return rval < lval;
+ }
+};
+
+template <SortOrder sort_order>
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ options_(options),
+ physical_type_(GetPhysicalType(array.type())),
+ output_(output) {}
+
+ Status Run() { return VisitTypeInline(*physical_type_, this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+ Status Visit(const DataType& type) {
+ return Status::TypeError("Unsupported type for ArraySelecter: ",
type.ToString());
+ }
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ ArrayType arr(array_.data());
+ std::vector<uint64_t> indices(arr.length());
+
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+ if (options_.k > arr.length()) {
+ options_.k = arr.length();
+ }
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + options_.k;
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ std::function<bool(uint64_t, uint64_t)> cmp;
+ SelectKComparator<sort_order> comparator;
+ cmp = [&arr, &comparator](uint64_t left, uint64_t right) -> bool {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return comparator(lval, rval);
+ };
+ arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin; ++iter) {
+ heap.Push(*iter);
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto lval = GetView::LogicalValue(arr.GetView(x_index));
+ const auto rval = GetView::LogicalValue(arr.GetView(heap.top()));
+ if (comparator(lval, rval)) {
+ heap.ReplaceTop(x_index);
+ }
+ }
+ if (options_.keep_duplicates == true) {
+ iter = indices_begin;
+ for (; iter != end_iter; ++iter) {
+ if (*iter != heap.top()) {
+ const auto lval = GetView::LogicalValue(arr.GetView(*iter));
+ const auto rval = GetView::LogicalValue(arr.GetView(heap.top()));
+ if (lval == rval) {
+ heap.Push(*iter);
+ }
+ }
+ }
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForFixedSizedType(uint64(), out_size,
ctx_->memory_pool()));
+
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.Pop();
+ --out_cbegin;
+ }
+ ARROW_ASSIGN_OR_RAISE(*output_, Take(array_,
Datum(std::move(take_indices)),
+ TakeOptions::NoBoundsCheck(), ctx_));
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Array& array_;
+ SelectKOptions options_;
+ const std::shared_ptr<DataType> physical_type_;
+ Datum* output_;
+};
+
+template <typename ArrayType>
+struct TypedHeapItem {
+ uint64_t index;
+ uint64_t offset;
+ ArrayType* array;
+};
+
+// Selects the first `k` values of a ChunkedArray, ranked with
+// SelectKComparator<sort_order>, and stores them in `*output` as one concatenated
+// array.
+//
+// Run() dispatches on the physical type of the input to SelectKthInternal<T>().
+// An input without chunks leaves `*output` untouched and reports success.
+//
+// NOTE(review): `options.k` is used for pointer arithmetic and is assumed to have
+// been validated as non-negative by the caller - confirm.
+template <SortOrder sort_order>
+class ChunkedArraySelecter : public TypeVisitor {
+ public:
+  ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array,
+                       const SelectKOptions& options, Datum* output)
+      : TypeVisitor(),
+        chunked_array_(chunked_array),
+        physical_type_(GetPhysicalType(chunked_array.type())),
+        physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+        options_(options),
+        ctx_(ctx),
+        output_(output) {}
+
+  // Runs the selection for the physical type of the input.
+  Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+  Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+  VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+  // Typed implementation: keeps the best `k` non-null rows in a bounded heap,
+  // optionally adds the rows that tie with the last selected one, then gathers the
+  // selected rows with Take() and concatenates the resulting chunks.
+  template <typename InType>
+  Status SelectKthInternal() {
+    using GetView = GetViewType<InType>;
+    using ArrayType = typename TypeTraits<InType>::ArrayType;
+    using HeapItem = TypedHeapItem<ArrayType>;
+
+    if (chunked_array_.num_chunks() == 0) {
+      return Status::OK();
+    }
+    // Never ask for more rows than the input holds.
+    if (options_.k > chunked_array_.length()) {
+      options_.k = chunked_array_.length();
+    }
+
+    SelectKComparator<sort_order> comparator;
+    // Use the lambda type directly; the heap calls it on every sift, so avoid the
+    // indirection of std::function.
+    auto cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool {
+      const auto lval = GetView::LogicalValue(left.array->GetView(left.index));
+      const auto rval = GetView::LogicalValue(right.array->GetView(right.index));
+      return comparator(lval, rval);
+    };
+    arrow::internal::Heap<HeapItem, decltype(cmp)> heap(cmp);
+
+    // Typed views of the non-empty chunks. Heap items point into these objects, so
+    // they must outlive the heap.
+    std::vector<std::shared_ptr<ArrayType>> chunks_holder;
+    uint64_t offset = 0;
+    for (const auto& chunk : physical_chunks_) {
+      if (chunk->length() == 0) continue;
+      chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data()));
+      ArrayType& arr = *chunks_holder.back();
+
+      std::vector<uint64_t> indices(arr.length());
+      uint64_t* indices_begin = indices.data();
+      uint64_t* indices_end = indices_begin + indices.size();
+      std::iota(indices_begin, indices_end, 0);
+
+      // Candidates are [indices_begin, end_iter); PartitionNulls is expected to move
+      // the null slots behind end_iter.
+      auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(
+          indices_begin, indices_end, arr, 0);
+      auto kth_begin = indices_begin + options_.k;
+      if (kth_begin > end_iter) {
+        kth_begin = end_iter;
+      }
+
+      // Fill the heap until it holds k items (it may already be partly or fully
+      // filled by earlier chunks) ...
+      uint64_t* iter = indices_begin;
+      for (; iter != kth_begin && heap.size() < static_cast<size_t>(options_.k);
+           ++iter) {
+        heap.Push(HeapItem{*iter, offset, &arr});
+      }
+      // ... then let every remaining candidate replace the top when it ranks before it.
+      for (; iter != end_iter && !heap.empty(); ++iter) {
+        const uint64_t x_index = *iter;
+        const auto xval = GetView::LogicalValue(arr.GetView(x_index));
+        const HeapItem top_item = heap.top();
+        const auto top_value =
+            GetView::LogicalValue(top_item.array->GetView(top_item.index));
+        if (comparator(xval, top_value)) {
+          heap.ReplaceTop(HeapItem{x_index, offset, &arr});
+        }
+      }
+      offset += chunk->length();
+    }
+
+    // Remember the heap top (the selected row that would be replaced first) before
+    // draining: it defines which values count as ties below.
+    const bool has_selection = !heap.empty();
+    HeapItem boundary{0, 0, nullptr};
+    if (has_selection) {
+      boundary = heap.top();
+    }
+
+    // The heap yields its top first, so fill the result back to front.
+    std::vector<uint64_t> selected(heap.size());
+    for (auto out = selected.rbegin(); !heap.empty(); ++out) {
+      const HeapItem item = heap.top();
+      *out = item.index + item.offset;
+      heap.Pop();
+    }
+
+    if (options_.keep_duplicates && has_selection) {
+      // Append every non-null row whose value equals the boundary value and that is
+      // not selected yet. Tracking the selected rows avoids emitting a row twice
+      // when several selected rows share the boundary value.
+      std::vector<uint8_t> taken(static_cast<size_t>(chunked_array_.length()));
+      for (const uint64_t row : selected) {
+        taken[row] = 1;
+      }
+      const auto boundary_value =
+          GetView::LogicalValue(boundary.array->GetView(boundary.index));
+      uint64_t chunk_offset = 0;
+      for (const auto& chunk : chunks_holder) {
+        const ArrayType& arr = *chunk;
+        const int64_t chunk_length = arr.length();
+        for (int64_t i = 0; i < chunk_length; ++i) {
+          const uint64_t row = chunk_offset + static_cast<uint64_t>(i);
+          // The value stored in a null slot is meaningless; never compare it.
+          if (taken[row] || arr.IsNull(i)) continue;
+          if (GetView::LogicalValue(arr.GetView(i)) == boundary_value) {
+            selected.push_back(row);
+          }
+        }
+        chunk_offset += static_cast<uint64_t>(chunk_length);
+      }
+    }
+
+    const int64_t out_size = static_cast<int64_t>(selected.size());
+    ARROW_ASSIGN_OR_RAISE(
+        auto take_indices,
+        MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool()));
+    std::copy(selected.begin(), selected.end(),
+              take_indices->GetMutableValues<uint64_t>(1));
+
+    ARROW_ASSIGN_OR_RAISE(auto chunked_select_k,
+                          Take(Datum(chunked_array_), Datum(std::move(take_indices)),
+                               TakeOptions::NoBoundsCheck(), ctx_));
+    ARROW_ASSIGN_OR_RAISE(
+        auto select_k,
+        Concatenate(chunked_select_k.chunked_array()->chunks(), ctx_->memory_pool()));
+    *output_ = Datum(select_k);
+    return Status::OK();
+  }
+
+  const ChunkedArray& chunked_array_;
+  const std::shared_ptr<DataType> physical_type_;
+  const ArrayVector physical_chunks_;
+  // Private copy of the options: `k` is clamped to the input length.
+  SelectKOptions options_;
+  ExecContext* ctx_;
+  Datum* output_;
+};
+
+// Selects the first `k` rows of a RecordBatch. Rows are ranked by the first sort key
+// with SelectKComparator<sort_order>; rows whose first key compares equal are ranked
+// by the remaining sort keys through MultipleKeyComparator.
+//
+// Run() reports an unknown or missing sort key as Status::Invalid; otherwise it
+// dispatches on the type of the first sort key and stores the result of Take() in
+// `*output`. A batch without rows leaves `*output` untouched and reports success.
+//
+// NOTE(review): `options.k` is used for pointer arithmetic and is assumed to have
+// been validated as non-negative by the caller - confirm.
+template <SortOrder sort_order>
+class RecordBatchSelecter : public TypeVisitor {
+ private:
+  using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey;
+  using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+  RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch,
+                      const SelectKOptions& options, Datum* output)
+      : TypeVisitor(),
+        ctx_(ctx),
+        record_batch_(record_batch),
+        options_(options),
+        output_(output),
+        sort_keys_(ResolveSortKeys(record_batch, options.keys, options.order, &status_)),
+        comparator_(sort_keys_) {}
+
+  // Returns the error recorded while resolving the sort keys, if any; otherwise runs
+  // the selection for the type of the first sort key.
+  Status Run() {
+    ARROW_RETURN_NOT_OK(status_);
+    if (sort_keys_.empty()) {
+      // Without this check the dispatch below would index an empty vector.
+      return Status::Invalid("Must specify one or more sort keys");
+    }
+    return sort_keys_[0].type->Accept(this);
+  }
+
+ protected:
+#define VISIT(TYPE) \
+  Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+  VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+  // Looks up every key column by name. On the first unknown name, stores an Invalid
+  // status in `*status` and returns the keys resolved so far.
+  static std::vector<ResolvedSortKey> ResolveSortKeys(
+      const RecordBatch& batch, const std::vector<std::string>& sort_keys,
+      SortOrder order, Status* status) {
+    std::vector<ResolvedSortKey> resolved;
+    for (const auto& key_name : sort_keys) {
+      auto array = batch.GetColumnByName(key_name);
+      if (!array) {
+        *status = Status::Invalid("Nonexistent sort key column: ", key_name);
+        break;
+      }
+      resolved.emplace_back(array, order);
+    }
+    return resolved;
+  }
+
+  // Typed implementation: keeps the best `k` rows in a bounded heap of row indices,
+  // optionally adds the rows that tie with the last selected one on every sort key,
+  // then gathers the selected rows with Take().
+  template <typename InType>
+  Status SelectKthInternal() {
+    using GetView = GetViewType<InType>;
+    using ArrayType = typename TypeTraits<InType>::ArrayType;
+    auto& comparator = comparator_;
+    const auto& first_sort_key = sort_keys_[0];
+    const ArrayType& arr = checked_cast<const ArrayType&>(first_sort_key.array);
+
+    const auto num_rows = record_batch_.num_rows();
+    if (num_rows == 0) {
+      return Status::OK();
+    }
+    // Never ask for more rows than the batch holds.
+    if (options_.k > num_rows) {
+      options_.k = num_rows;
+    }
+
+    SelectKComparator<sort_order> select_k_comparator;
+    // Use the lambda type directly; the heap calls it on every sift, so avoid the
+    // indirection of std::function.
+    auto cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+      const auto lval = GetView::LogicalValue(arr.GetView(left));
+      const auto rval = GetView::LogicalValue(arr.GetView(right));
+      if (lval == rval) {
+        // Equal on the first key: the second and following sort keys decide.
+        return comparator.Compare(left, right, 1);
+      }
+      return select_k_comparator(lval, rval);
+    };
+    arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+    std::vector<uint64_t> indices(arr.length());
+    uint64_t* indices_begin = indices.data();
+    uint64_t* indices_end = indices_begin + indices.size();
+    std::iota(indices_begin, indices_end, 0);
+
+    // Candidates are [indices_begin, end_iter); PartitionNulls is expected to move
+    // the rows whose first key is null behind end_iter.
+    auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(indices_begin,
+                                                                    indices_end, arr, 0);
+    auto kth_begin = indices_begin + options_.k;
+    if (kth_begin > end_iter) {
+      kth_begin = end_iter;
+    }
+
+    // Seed the heap with the first k candidates, then let every remaining candidate
+    // replace the top when it ranks before it.
+    uint64_t* iter = indices_begin;
+    for (; iter != kth_begin; ++iter) {
+      heap.Push(*iter);
+    }
+    for (; iter != end_iter && !heap.empty(); ++iter) {
+      const uint64_t x_index = *iter;
+      const uint64_t top_item = heap.top();
+      if (cmp(x_index, top_item)) {
+        heap.ReplaceTop(x_index);
+      }
+    }
+
+    // The heap yields its top first, so fill the result back to front. The former
+    // top therefore ends up as the last element.
+    std::vector<uint64_t> selected(heap.size());
+    for (auto out = selected.rbegin(); !heap.empty(); ++out) {
+      *out = heap.top();
+      heap.Pop();
+    }
+
+    if (options_.keep_duplicates && !selected.empty()) {
+      // Append every candidate row that equals the boundary row on all sort keys and
+      // is not selected yet. Tracking the selected rows avoids emitting a row twice
+      // when several selected rows tie with the boundary row.
+      const uint64_t boundary = selected.back();
+      std::vector<uint8_t> taken(indices.size());
+      for (const uint64_t row : selected) {
+        taken[row] = 1;
+      }
+      const auto boundary_value = GetView::LogicalValue(arr.GetView(boundary));
+      for (iter = indices_begin; iter != end_iter; ++iter) {
+        const uint64_t x_index = *iter;
+        if (taken[x_index]) continue;
+        const auto xval = GetView::LogicalValue(arr.GetView(x_index));
+        if (xval == boundary_value && comparator.Equals(x_index, boundary, 1)) {
+          selected.push_back(x_index);
+        }
+      }
+    }
+
+    const int64_t out_size = static_cast<int64_t>(selected.size());
+    ARROW_ASSIGN_OR_RAISE(
+        auto take_indices,
+        MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool()));
+    std::copy(selected.begin(), selected.end(),
+              take_indices->GetMutableValues<uint64_t>(1));
+
+    ARROW_ASSIGN_OR_RAISE(*output_,
+                          Take(Datum(record_batch_), Datum(std::move(take_indices)),
+                               TakeOptions::NoBoundsCheck(), ctx_));
+    return Status::OK();
+  }
+
+  ExecContext* ctx_;
+  const RecordBatch& record_batch_;
+  // Private copy of the options: `k` is clamped to the number of rows.
+  SelectKOptions options_;
+  Datum* output_;
+  // Must be declared before sort_keys_: members are initialized in declaration
+  // order and ResolveSortKeys() writes its error into status_ while sort_keys_ is
+  // being initialized.
+  Status status_;
+  std::vector<ResolvedSortKey> sort_keys_;
+  Comparator comparator_;
+};
+
+// Selects the first `k` rows of a Table. Rows are ranked by the first sort key with
+// SelectKComparator<sort_order>; rows whose first key compares equal are ranked by
+// the remaining sort keys through MultipleKeyComparator.
+//
+// Run() reports an unknown or missing sort key as Status::Invalid; otherwise it
+// dispatches on the type of the first sort key and stores the result of Take() in
+// `*output`. A table without rows leaves `*output` untouched and reports success.
+//
+// NOTE(review): `options.k` is used for pointer arithmetic and is assumed to have
+// been validated as non-negative by the caller - confirm.
+template <SortOrder sort_order>
+class TableSelecter : public TypeVisitor {
+ private:
+  using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey;
+  using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+  TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions& options,
+                Datum* output)
+      : TypeVisitor(),
+        ctx_(ctx),
+        table_(table),
+        options_(options),
+        output_(output),
+        sort_keys_(ResolveSortKeys(table, options.keys, options.order, &status_)),
+        comparator_(sort_keys_) {}
+
+  // Returns the error recorded while resolving the sort keys, if any; otherwise runs
+  // the selection for the type of the first sort key.
+  Status Run() {
+    ARROW_RETURN_NOT_OK(status_);
+    if (sort_keys_.empty()) {
+      // Without this check the dispatch below would index an empty vector.
+      return Status::Invalid("Must specify one or more sort keys");
+    }
+    return sort_keys_[0].type->Accept(this);
+  }
+
+ protected:
+#define VISIT(TYPE) \
+  Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+  VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+  // Looks up every key column by name. On the first unknown name, stores an Invalid
+  // status in `*status` and returns the keys resolved so far.
+  static std::vector<ResolvedSortKey> ResolveSortKeys(
+      const Table& table, const std::vector<std::string>& sort_keys, SortOrder order,
+      Status* status) {
+    std::vector<ResolvedSortKey> resolved;
+    for (const auto& key_name : sort_keys) {
+      auto chunked_array = table.GetColumnByName(key_name);
+      if (!chunked_array) {
+        *status = Status::Invalid("Nonexistent sort key column: ", key_name);
+        break;
+      }
+      resolved.emplace_back(*chunked_array, order);
+    }
+    return resolved;
+  }
+
+  // Behaves like PartitionNulls() but this supports multiple sort keys.
+  //
+  // For non-float types: moves the rows whose first key is null to the end, orders
+  // those rows by the second and following sort keys, and returns the start of the
+  // null range.
+  template <typename Type>
+  enable_if_t<!is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+      uint64_t* indices_begin, uint64_t* indices_end,
+      const ResolvedSortKey& first_sort_key) {
+    using ArrayType = typename TypeTraits<Type>::ArrayType;
+    if (first_sort_key.null_count == 0) {
+      return indices_end;
+    }
+    StablePartitioner partitioner;
+    auto nulls_begin =
+        partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t index) {
+          const auto chunk =
+              first_sort_key.GetChunk<ArrayType>(static_cast<int64_t>(index));
+          return !chunk.IsNull();
+        });
+    DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+    auto& comparator = comparator_;
+    std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) {
+      return comparator.Compare(left, right, 1);
+    });
+    return nulls_begin;
+  }
+
+  // Behaves like PartitionNulls() but this supports multiple sort keys.
+  //
+  // For float types: in addition to the nulls, NaNs are moved behind the regular
+  // values (and in front of the nulls). Returns the start of the NaN range.
+  template <typename Type>
+  enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+      uint64_t* indices_begin, uint64_t* indices_end,
+      const ResolvedSortKey& first_sort_key) {
+    using ArrayType = typename TypeTraits<Type>::ArrayType;
+    StablePartitioner partitioner;
+    uint64_t* nulls_begin;
+    if (first_sort_key.null_count == 0) {
+      nulls_begin = indices_end;
+    } else {
+      nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t index) {
+        const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+        return !chunk.IsNull();
+      });
+    }
+    DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+    uint64_t* nans_begin = partitioner(indices_begin, nulls_begin, [&](uint64_t index) {
+      const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+      return !std::isnan(chunk.Value());
+    });
+    auto& comparator = comparator_;
+    // Sort all NaNs by the second and following sort keys.
+    std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t right) {
+      return comparator.Compare(left, right, 1);
+    });
+    // Sort all nulls by the second and following sort keys.
+    std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) {
+      return comparator.Compare(left, right, 1);
+    });
+    return nans_begin;
+  }
+
+  // Typed implementation: keeps the best `k` rows in a bounded heap of row indices,
+  // optionally adds the rows that tie with the last selected one on every sort key,
+  // then gathers the selected rows with Take().
+  template <typename InType>
+  Status SelectKthInternal() {
+    using ArrayType = typename TypeTraits<InType>::ArrayType;
+    auto& comparator = comparator_;
+    const auto& first_sort_key = sort_keys_[0];
+
+    const auto num_rows = table_.num_rows();
+    if (num_rows == 0) {
+      return Status::OK();
+    }
+    // Never ask for more rows than the table holds.
+    if (options_.k > num_rows) {
+      options_.k = num_rows;
+    }
+
+    SelectKComparator<sort_order> select_k_comparator;
+    // Use the lambda type directly; the heap calls it on every sift, so avoid the
+    // indirection of std::function.
+    auto cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+      auto chunk_left = first_sort_key.GetChunk<ArrayType>(left);
+      auto chunk_right = first_sort_key.GetChunk<ArrayType>(right);
+      auto value_left = chunk_left.Value();
+      auto value_right = chunk_right.Value();
+      if (value_left == value_right) {
+        // Equal on the first key: the second and following sort keys decide.
+        return comparator.Compare(left, right, 1);
+      }
+      return select_k_comparator(value_left, value_right);
+    };
+    arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+    std::vector<uint64_t> indices(num_rows);
+    uint64_t* indices_begin = indices.data();
+    uint64_t* indices_end = indices_begin + indices.size();
+    std::iota(indices_begin, indices_end, 0);
+
+    // Candidates are [indices_begin, end_iter): rows whose first key is neither null
+    // nor (for float keys) NaN.
+    auto end_iter =
+        this->PartitionNullsInternal<InType>(indices_begin, indices_end, first_sort_key);
+    auto kth_begin = indices_begin + options_.k;
+    if (kth_begin > end_iter) {
+      kth_begin = end_iter;
+    }
+
+    // Seed the heap with the first k candidates, then let every remaining candidate
+    // replace the top when it ranks before it.
+    uint64_t* iter = indices_begin;
+    for (; iter != kth_begin; ++iter) {
+      heap.Push(*iter);
+    }
+    for (; iter != end_iter && !heap.empty(); ++iter) {
+      const uint64_t x_index = *iter;
+      const uint64_t top_item = heap.top();
+      if (cmp(x_index, top_item)) {
+        heap.ReplaceTop(x_index);
+      }
+    }
+
+    // The heap yields its top first, so fill the result back to front. The former
+    // top therefore ends up as the last element.
+    std::vector<uint64_t> selected(heap.size());
+    for (auto out = selected.rbegin(); !heap.empty(); ++out) {
+      *out = heap.top();
+      heap.Pop();
+    }
+
+    if (options_.keep_duplicates && !selected.empty()) {
+      // Append every candidate row that equals the boundary row on all sort keys and
+      // is not selected yet. Tracking the selected rows avoids emitting a row twice
+      // when several selected rows tie with the boundary row.
+      const uint64_t boundary = selected.back();
+      std::vector<uint8_t> taken(indices.size());
+      for (const uint64_t row : selected) {
+        taken[row] = 1;
+      }
+      auto boundary_chunk = first_sort_key.GetChunk<ArrayType>(boundary);
+      const auto boundary_value = boundary_chunk.Value();
+      for (iter = indices_begin; iter != end_iter; ++iter) {
+        const uint64_t x_index = *iter;
+        if (taken[x_index]) continue;
+        auto x_chunk = first_sort_key.GetChunk<ArrayType>(x_index);
+        const auto xval = x_chunk.Value();
+        if (xval == boundary_value && comparator.Equals(x_index, boundary, 1)) {
+          selected.push_back(x_index);
+        }
+      }
+    }
+
+    const int64_t out_size = static_cast<int64_t>(selected.size());
+    ARROW_ASSIGN_OR_RAISE(
+        auto take_indices,
+        MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool()));
+    std::copy(selected.begin(), selected.end(),
+              take_indices->GetMutableValues<uint64_t>(1));
+
+    ARROW_ASSIGN_OR_RAISE(*output_, Take(Datum(table_), Datum(std::move(take_indices)),
+                                         TakeOptions::NoBoundsCheck(), ctx_));
+    return Status::OK();
+  }
+
+  ExecContext* ctx_;
+  const Table& table_;
+  // Private copy of the options: `k` is clamped to the number of rows.
+  SelectKOptions options_;
+  Datum* output_;
+  // Must be declared before sort_keys_: members are initialized in declaration
+  // order and ResolveSortKeys() writes its error into status_ while sort_keys_ is
+  // being initialized.
+  Status status_;
+  std::vector<ResolvedSortKey> sort_keys_;
+  Comparator comparator_;
+};
Review comment:
Ping here - we should at least file a follow-up issue. The selecters are all
very similar, and I believe they could be consolidated somewhat.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]