lidavidm commented on a change in pull request #11019:
URL: https://github.com/apache/arrow/pull/11019#discussion_r705403591
##########
File path: python/pyarrow/tests/test_compute.py
##########
@@ -127,11 +127,12 @@ def test_option_class_equality():
pc.SetLookupOptions(value_set=pa.array([1])),
pc.SliceOptions(start=0, stop=1, step=1),
pc.SplitPatternOptions(pattern="pattern"),
+ pc.SelectKOptions(k=0, sort_keys=[("b", "ascending")]),
pc.StrptimeOptions("%Y", "s"),
pc.TrimOptions(" "),
pc.StrftimeOptions(),
]
- classes = {type(option) for option in options}
+ classes={type(option) for option in options}
Review comment:
Why all these whitespace changes?
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,774 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (order == SortOrder::Ascending) {
+ return lval <= rval;
+ } else {
+ return rval <= lval;
+ }
+ }
+};
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return SelectKUnstable(Datum(values), SelectKOptions::TopKDefault(k));
+ } else {
+ return SelectKUnstable(Datum(values), SelectKOptions::BottomKDefault(k));
+ }
+}
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return SelectKUnstable(Datum(values), SelectKOptions::TopKDefault(k));
+ } else {
+ return SelectKUnstable(Datum(values), SelectKOptions::BottomKDefault(k));
+ }
+}
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const RecordBatch& values,
+ const SelectKOptions& options) {
+ if (order == SortOrder::Descending) {
+ return SelectKUnstable(Datum(values), options);
+ } else {
+ return SelectKUnstable(Datum(values), options);
+ }
+}
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Table& values,
+ const SelectKOptions& options) {
+ if (order == SortOrder::Descending) {
+ return SelectKUnstable(Datum(values), options);
+ } else {
+ return SelectKUnstable(Datum(values), options);
+ }
+}
+
+void ValidateSelectK(const Array& array, int64_t k, Array& select_k_indices,
+ SortOrder order, bool stable_sort = false) {
+ ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order));
+
+ if (k < array.length()) {
+ // head(k)
+ auto head_k_indices = sorted_indices->Slice(0, select_k_indices.length());
+ if (stable_sort) {
+ AssertArraysEqual(*head_k_indices, select_k_indices);
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ Take(array, *head_k_indices,
TakeOptions::NoBoundsCheck()));
+ ASSERT_OK_AND_ASSIGN(auto actual,
+ Take(array, select_k_indices,
TakeOptions::NoBoundsCheck()));
+ AssertArraysEqual(*expected, *actual);
+ }
+ }
+}
+
+void ValidateSelectK(const ChunkedArray& chunked_array, int64_t k,
Review comment:
Same goes here as for the `SelectK` overloads above: I would expect we only need one overload.
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,774 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (order == SortOrder::Ascending) {
+ return lval <= rval;
+ } else {
+ return rval <= lval;
+ }
+ }
+};
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
Review comment:
Do we need all these overloads? Datum has implicit constructors:
https://github.com/apache/arrow/blob/bb1ef850d3a6f5b998ff4dc1f196ee9cb7c273a3/cpp/src/arrow/datum.h
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -131,8 +165,6 @@ class ARROW_EXPORT PartitionNthOptions : public
FunctionOptions {
int64_t pivot;
};
-/// @}
Review comment:
Did you mean to remove this? The `/// @}` marker tells Doxygen where a group ends.
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1798,650 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+const auto kDefaultSelectKOptions = SelectKOptions::Defaults();
+
+const FunctionDoc select_k_doc(
+ "Returns the first k elements ordered by `options.keys`",
+ ("This function computes the k elements of the input\n"
+ "array, record batch or table specified in the column names
(`options.sort_keys`).\n"
+ "The columns that are not specified are returned as well, but not used
for\n"
+ "ordering. Null values are considered greater than any other value and
are\n"
+ "therefore sorted at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions");
+
+const FunctionDoc top_k_doc(
+ "Returns the first k elements ordered by `options.keys` in ascending
order",
+ ("This function computes the k largest elements in ascending order of the
input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions::TopKDefault");
+
+const FunctionDoc bottom_k_doc(
+ "Returns the first k elements ordered by `options.keys` in descending
order",
+ ("This function computes the k smallest elements in descending order of
the input\n"
+ "array, record batch or table specified in the column names
(`options.keys`). The\n"
+ "columns that are not specified are returned as well, but not used for
ordering.\n"
+ "Null values are considered greater than any other value and are
therefore sorted\n"
+ "at the end of the array.\n"
+ "For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SelectKOptions::BottomKDefault");
+
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForNumericBasedType(
+ std::shared_ptr<DataType> out_type, int64_t length, MemoryPool*
memory_pool) {
+ auto buffer_size = BitUtil::BytesForBits(
+ length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+ std::vector<std::shared_ptr<Buffer>> buffers(2);
+ ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size,
memory_pool));
+ auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+ return out;
+}
+
+template <SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval);
+};
+
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return lval < rval;
+ }
+};
+
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ return rval < lval;
+ }
+};
+
+template <SortOrder sort_order>
+class ArraySelecter : public TypeVisitor {
+ public:
+ ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ array_(array),
+ k_(options.k),
+ physical_type_(GetPhysicalType(array.type())),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+ ArrayType arr(array_.data());
+ std::vector<uint64_t> indices(arr.length());
+
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+ if (k_ > arr.length()) {
+ k_ = arr.length();
+ }
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + k_;
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ SelectKComparator<sort_order> comparator;
+ auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return comparator(lval, rval);
+ };
+ using HeapContainer =
+ std::priority_queue<uint64_t, std::vector<uint64_t>, decltype(cmp)>;
+ HeapContainer heap(indices_begin, kth_begin, cmp);
+ for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ if (cmp(x_index, heap.top())) {
+ heap.pop();
+ heap.push(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Array& array_;
+ int64_t k_;
+ const std::shared_ptr<DataType> physical_type_;
+ Datum* output_;
+};
+
+template <typename ArrayType>
+struct TypedHeapItem {
+ uint64_t index;
+ uint64_t offset;
+ ArrayType* array;
+};
+
+template <SortOrder sort_order>
+class ChunkedArraySelecter : public TypeVisitor {
+ public:
+ ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ chunked_array_(chunked_array),
+ physical_type_(GetPhysicalType(chunked_array.type())),
+ physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)),
+ k_(options.k),
+ ctx_(ctx),
+ output_(output) {}
+
+ Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ template <typename InType>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ using HeapItem = TypedHeapItem<ArrayType>;
+
+ const auto num_chunks = chunked_array_.num_chunks();
+ if (num_chunks == 0) {
+ return Status::OK();
+ }
+ if (k_ > chunked_array_.length()) {
+ k_ = chunked_array_.length();
+ }
+ std::function<bool(const HeapItem&, const HeapItem&)> cmp;
+ SelectKComparator<sort_order> comparator;
+
+ cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool {
+ const auto lval = GetView::LogicalValue(left.array->GetView(left.index));
+ const auto rval =
GetView::LogicalValue(right.array->GetView(right.index));
+ return comparator(lval, rval);
+ };
+ using HeapContainer =
+ std::priority_queue<HeapItem, std::vector<HeapItem>, decltype(cmp)>;
+
+ HeapContainer heap(cmp);
+ std::vector<std::shared_ptr<ArrayType>> chunks_holder;
+ uint64_t offset = 0;
+ for (const auto& chunk : physical_chunks_) {
+ if (chunk->length() == 0) continue;
+ chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data()));
+ ArrayType& arr = *chunks_holder[chunks_holder.size() - 1];
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(
+ indices_begin, indices_end, arr, 0);
+ auto kth_begin = indices_begin + k_;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ uint64_t* iter = indices_begin;
+ for (; iter != kth_begin && heap.size() < static_cast<size_t>(k_);
++iter) {
+ heap.push(HeapItem{*iter, offset, &arr});
+ }
+ for (; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ const auto& xval = GetView::LogicalValue(arr.GetView(x_index));
+ auto top_item = heap.top();
+ const auto& top_value =
+ GetView::LogicalValue(top_item.array->GetView(top_item.index));
+ if (comparator(xval, top_value)) {
+ heap.pop();
+ heap.push(HeapItem{x_index, offset, &arr});
+ }
+ }
+ offset += chunk->length();
+ }
+
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ auto top_item = heap.top();
+ *out_cbegin = top_item.index + top_item.offset;
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ const ChunkedArray& chunked_array_;
+ const std::shared_ptr<DataType> physical_type_;
+ const ArrayVector physical_chunks_;
+ int64_t k_;
+ ExecContext* ctx_;
+ Datum* output_;
+};
+
+class RecordBatchSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch,
+ const SelectKOptions& options, Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ record_batch_(record_batch),
+ k_(options.k),
+ output_(output),
+ sort_keys_(ResolveSortKeys(record_batch, options.sort_keys)),
+ comparator_(sort_keys_) {}
+
+ Status Run() { return sort_keys_[0].type->Accept(this); }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const RecordBatch& batch, const std::vector<SortKey>& sort_keys) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto array = batch.GetColumnByName(key.name);
+ resolved.emplace_back(array, key.order);
+ }
+ return resolved;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using GetView = GetViewType<InType>;
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+ const ArrayType& arr = checked_cast<const
ArrayType&>(first_sort_key.array);
+
+ const auto num_rows = record_batch_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (k_ > record_batch_.num_rows()) {
+ k_ = record_batch_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ if (lval == rval) {
+ // If the left value equals to the right value,
+ // we need to compare the second and following
+ // sort keys.
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(lval, rval);
+ };
+ using HeapContainer =
+ std::priority_queue<uint64_t, std::vector<uint64_t>, decltype(cmp)>;
+
+ std::vector<uint64_t> indices(arr.length());
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter = PartitionNulls<ArrayType,
NonStablePartitioner>(indices_begin,
+
indices_end, arr, 0);
+ auto kth_begin = indices_begin + k_;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ HeapContainer heap(indices_begin, kth_begin, cmp);
+ for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ auto top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.pop();
+ heap.push(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const RecordBatch& record_batch_;
+ int64_t k_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+};
+
+class TableSelecter : public TypeVisitor {
+ private:
+ using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey;
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+ TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions&
options,
+ Datum* output)
+ : TypeVisitor(),
+ ctx_(ctx),
+ table_(table),
+ k_(options.k),
+ output_(output),
+ sort_keys_(ResolveSortKeys(table, options.sort_keys)),
+ comparator_(sort_keys_) {}
+
+ Status Run() { return sort_keys_[0].type->Accept(this); }
+
+ protected:
+#define VISIT(TYPE) \
+ Status Visit(const TYPE& type) { \
+ if (sort_keys_[0].order == SortOrder::Descending) \
+ return SelectKthInternal<TYPE, SortOrder::Descending>(); \
+ return SelectKthInternal<TYPE, SortOrder::Ascending>(); \
+ }
+ VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ static std::vector<ResolvedSortKey> ResolveSortKeys(
+ const Table& table, const std::vector<SortKey>& sort_keys) {
+ std::vector<ResolvedSortKey> resolved;
+ for (const auto& key : sort_keys) {
+ auto chunked_array = table.GetColumnByName(key.name);
+ resolved.emplace_back(*chunked_array, key.order);
+ }
+ return resolved;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For non-float types.
+ template <typename Type>
+ enable_if_t<!is_floating_type<Type>::value, uint64_t*>
PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ if (first_sort_key.null_count == 0) {
+ return indices_end;
+ }
+ StablePartitioner partitioner;
+ auto nulls_begin =
+ partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t
index) {
+ const auto chunk =
first_sort_key.GetChunk<ArrayType>((int64_t)index);
+ return !chunk.IsNull();
+ });
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ auto& comparator = comparator_;
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nulls_begin;
+ }
+
+ // Behaves like PatitionNulls() but this supports multiple sort keys.
+ //
+ // For float types.
+ template <typename Type>
+ enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+ uint64_t* indices_begin, uint64_t* indices_end,
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ StablePartitioner partitioner;
+ uint64_t* nulls_begin;
+ if (first_sort_key.null_count == 0) {
+ nulls_begin = indices_end;
+ } else {
+ nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t
index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !chunk.IsNull();
+ });
+ }
+ DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count);
+ uint64_t* nans_begin = partitioner(indices_begin, nulls_begin,
[&](uint64_t index) {
+ const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+ return !std::isnan(chunk.Value());
+ });
+ auto& comparator = comparator_;
+ // Sort all NaNs by the second and following sort keys.
+ std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ // Sort all nulls by the second and following sort keys.
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t
right) {
+ return comparator.Compare(left, right, 1);
+ });
+ return nans_begin;
+ }
+
+ template <typename InType, SortOrder sort_order>
+ Status SelectKthInternal() {
+ using ArrayType = typename TypeTraits<InType>::ArrayType;
+ auto& comparator = comparator_;
+ const auto& first_sort_key = sort_keys_[0];
+
+ const auto num_rows = table_.num_rows();
+ if (num_rows == 0) {
+ return Status::OK();
+ }
+ if (k_ > table_.num_rows()) {
+ k_ = table_.num_rows();
+ }
+ std::function<bool(const uint64_t&, const uint64_t&)> cmp;
+ SelectKComparator<sort_order> select_k_comparator;
+ cmp = [&](const uint64_t& left, const uint64_t& right) -> bool {
+ auto chunk_left = first_sort_key.template GetChunk<ArrayType>(left);
+ auto chunk_right = first_sort_key.template GetChunk<ArrayType>(right);
+ auto value_left = chunk_left.Value();
+ auto value_right = chunk_right.Value();
+ if (value_left == value_right) {
+ return comparator.Compare(left, right, 1);
+ }
+ return select_k_comparator(value_left, value_right);
+ };
+ using HeapContainer =
+ std::priority_queue<uint64_t, std::vector<uint64_t>, decltype(cmp)>;
+
+ std::vector<uint64_t> indices(num_rows);
+ uint64_t* indices_begin = indices.data();
+ uint64_t* indices_end = indices_begin + indices.size();
+ std::iota(indices_begin, indices_end, 0);
+
+ auto end_iter =
+ this->PartitionNullsInternal<InType>(indices_begin, indices_end,
first_sort_key);
+ auto kth_begin = indices_begin + k_;
+
+ if (kth_begin > end_iter) {
+ kth_begin = end_iter;
+ }
+ HeapContainer heap(indices_begin, kth_begin, cmp);
+ for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) {
+ uint64_t x_index = *iter;
+ uint64_t top_item = heap.top();
+ if (cmp(x_index, top_item)) {
+ heap.pop();
+ heap.push(x_index);
+ }
+ }
+ int64_t out_size = static_cast<int64_t>(heap.size());
+ ARROW_ASSIGN_OR_RAISE(
+ auto take_indices,
+ MakeMutableArrayForNumericBasedType(uint64(), out_size,
ctx_->memory_pool()));
+ auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size
- 1;
+ while (heap.size() > 0) {
+ *out_cbegin = heap.top();
+ heap.pop();
+ --out_cbegin;
+ }
+ *output_ = Datum(take_indices);
+ return Status::OK();
+ }
+
+ ExecContext* ctx_;
+ const Table& table_;
+ int64_t k_;
+ Datum* output_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ Comparator comparator_;
+};
+
+template <typename ArrowContainer>
+static Status CheckConsistency(const ArrowContainer& container,
Review comment:
You can avoid templating this by taking `const Schema& schema` instead.
You may also want to use `FieldRef::FindOne`, since it will also tell you whether
the field can be referenced unambiguously.
https://github.com/apache/arrow/blob/bb1ef850d3a6f5b998ff4dc1f196ee9cb7c273a3/cpp/src/arrow/type.h#L1631
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]