cyb70289 commented on a change in pull request #11019:
URL: https://github.com/apache/arrow/pull/11019#discussion_r704055748



##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1799,697 @@ class SortIndicesMetaFunction : public MetaFunction {
   }
 };
 
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+// Options wrapper around SelectKOptions (presumably used as the kernel state for
+// the select-k kernels; confirm against the kernel registration).
+using SelectKOptionsState = internal::OptionsWrapper<SelectKOptions>;
+
+// Default option instances, each built from SelectKOptions::Defaults().
+// NOTE(review): the TopK/BottomK defaults are currently identical to the generic
+// SelectK default; consider consolidating them once all users are known.
+const SelectKOptions kDefaultSelectKOptions = SelectKOptions::Defaults();
+const SelectKOptions kDefaultTopKOptions = SelectKOptions::Defaults();
+const SelectKOptions kDefaultBottomKOptions = SelectKOptions::Defaults();
+
+// User-facing documentation for the top-k selection function.
+// NOTE(review): the summary ("first k elements ... in ascending order") and the
+// description ("k largest elements in ascending order") appear to disagree: the
+// first k elements in ascending order are the smallest ones. Confirm which
+// ordering the kernel actually produces and align both sentences.
+const FunctionDoc top_k_doc(
+    "Returns the first k elements ordered by `options.keys` in ascending order",
+    ("This function computes the k largest elements in ascending order of the input\n"
+     "array, record batch or table specified in the column names (`options.keys`). The\n"
+     "columns that are not specified are returned as well, but not used for ordering.\n"
+     "Null values are considered greater than any other value and are therefore sorted\n"
+     "at the end of the array.\n"
+     "For floating-point types, NaNs are considered greater than any\n"
+     "other non-null value, but smaller than null values."),
+    {"input"}, "SelectKOptions");
+
+// User-facing documentation for the bottom-k selection function.
+// NOTE(review): the summary ("first k elements ... in descending order") and the
+// description ("k smallest elements in descending order") appear to disagree: the
+// first k elements in descending order are the largest ones. Confirm which
+// ordering the kernel actually produces and align both sentences.
+const FunctionDoc bottom_k_doc(
+    "Returns the first k elements ordered by `options.keys` in descending order",
+    ("This function computes the k smallest elements in descending order of the input\n"
+     "array, record batch or table specified in the column names (`options.keys`). The\n"
+     "columns that are not specified are returned as well, but not used for ordering.\n"
+     "Null values are considered greater than any other value and are therefore sorted\n"
+     "at the end of the array.\n"
+     "For floating-point types, NaNs are considered greater than any\n"
+     "other non-null value, but smaller than null values."),
+    {"input"}, "SelectKOptions");
+
+// Allocates an ArrayData of `length` elements of type `out_type`, backed by a
+// resizable data buffer (buffers[1]) sized from the type's bit width.
+// No validity bitmap is allocated (buffers[0] is left null) and null_count is 0.
+// This function does not initialize the data buffer; the caller must fill it.
+//
+// `out_type` must be a fixed-width (numeric-like) type deriving from FixedWidthType.
+// Returns an error Status if the buffer allocation fails.
+Result<std::shared_ptr<ArrayData>> MakeMutableArrayForNumericBasedType(
+    std::shared_ptr<DataType> out_type, int64_t length, MemoryPool* memory_pool) {
+  // Cast to the common fixed-width base class. The previous cast to UInt64Type
+  // was an invalid downcast for every output type other than uint64.
+  const int bit_width = static_cast<const FixedWidthType&>(*out_type).bit_width();
+  const int64_t buffer_size = BitUtil::BytesForBits(length * bit_width);
+
+  std::vector<std::shared_ptr<Buffer>> buffers(2);
+  ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size, memory_pool));
+  return std::make_shared<ArrayData>(std::move(out_type), length, std::move(buffers),
+                                     /*null_count=*/0);
+}
+
+// Binary "orders-before" predicate for the select-k kernels, selected by sort order.
+// Only the Ascending and Descending specializations below provide a definition;
+// the primary template merely declares operator().
+// operator() is const so the comparator can be used through const references and
+// as a standard-library Compare object.
+template <SortOrder order>
+class SelectKComparator {
+ public:
+  template <typename Type>
+  bool operator()(const Type& lval, const Type& rval) const;
+};
+
+// Ascending: `lval` orders before `rval` when `lval < rval`.
+template <>
+class SelectKComparator<SortOrder::Ascending> {
+ public:
+  template <typename Type>
+  bool operator()(const Type& lval, const Type& rval) const {
+    return lval < rval;
+  }
+};
+
+// Descending: `lval` orders before `rval` when `rval < lval`. Like the ascending
+// case, this only requires operator< on Type.
+template <>
+class SelectKComparator<SortOrder::Descending> {
+ public:
+  template <typename Type>
+  bool operator()(const Type& lval, const Type& rval) const {
+    return rval < lval;
+  }
+};
+
+template <SortOrder sort_order>
+class ArraySelecter : public TypeVisitor {
+ public:
+  ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions& 
options,
+                Datum* output)
+      : TypeVisitor(),
+        ctx_(ctx),
+        array_(array),
+        options_(options),
+        physical_type_(GetPhysicalType(array.type())),
+        output_(output) {}
+
+  Status Run() { return physical_type_->Accept(this); }
+
+#define VISIT(TYPE) \
+  Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); }
+
+  VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+  template <typename InType>
+  Status SelectKthInternal() {
+    using GetView = GetViewType<InType>;
+    using ArrayType = typename TypeTraits<InType>::ArrayType;
+
+    ArrayType arr(array_.data());
+    std::vector<uint64_t> indices(arr.length());
+
+    uint64_t* indices_begin = indices.data();
+    uint64_t* indices_end = indices_begin + indices.size();
+    std::iota(indices_begin, indices_end, 0);
+    if (options_.k > arr.length()) {
+      options_.k = arr.length();
+    }
+    auto end_iter = PartitionNulls<ArrayType, 
NonStablePartitioner>(indices_begin,
+                                                                    
indices_end, arr, 0);
+    auto kth_begin = indices_begin + options_.k;
+    if (kth_begin > end_iter) {
+      kth_begin = end_iter;
+    }
+    SelectKComparator<sort_order> comparator;
+    auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) {
+      const auto lval = GetView::LogicalValue(arr.GetView(left));
+      const auto rval = GetView::LogicalValue(arr.GetView(right));
+      return comparator(lval, rval);
+    };
+    arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp);
+
+    uint64_t* iter = indices_begin;
+    for (; iter != kth_begin; ++iter) {
+      heap.Push(*iter);
+    }
+

Review comment:
       Instead of implementing a heap ourselves, can we use `std::priority_queue`
directly here? Are there any special requirements that `std::priority_queue`
cannot meet?
   
   Another small catch: this heap initialization is O(n log n), because it
inserts the elements one by one. `std::priority_queue` can build its heap directly
from a vector; that constructor is specified to call `std::make_heap`, which is O(n).




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to