lidavidm commented on a change in pull request #11019:
URL: https://github.com/apache/arrow/pull/11019#discussion_r700206299
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -120,6 +120,25 @@ class ARROW_EXPORT SortOptions : public FunctionOptions {
std::vector<SortKey> sort_keys;
};
+/// \brief Partitioning options for NthToIndices
+class ARROW_EXPORT SelectKOptions : public FunctionOptions {
+ public:
+ explicit SelectKOptions(int64_t pivot = 0, std::vector<std::string> keys =
{},
+ std::string keep = "first",
Review comment:
We should use an enum instead of a string for modes like this.
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -120,6 +120,25 @@ class ARROW_EXPORT SortOptions : public FunctionOptions {
std::vector<SortKey> sort_keys;
};
+/// \brief Partitioning options for NthToIndices
+class ARROW_EXPORT SelectKOptions : public FunctionOptions {
+ public:
+ explicit SelectKOptions(int64_t pivot = 0, std::vector<std::string> keys =
{},
Review comment:
Here it's called pivot instead of k - is that intentional?
##########
File path: cpp/src/arrow/compute/api_vector.cc
##########
@@ -111,6 +111,10 @@ static auto kSortOptionsType =
GetFunctionOptionsType<SortOptions>(DataMember("sort_keys",
&SortOptions::sort_keys));
static auto kPartitionNthOptionsType =
GetFunctionOptionsType<PartitionNthOptions>(
DataMember("pivot", &PartitionNthOptions::pivot));
+static auto kSelectKOptionsType = GetFunctionOptionsType<SelectKOptions>(
+ DataMember("k", &SelectKOptions::k), DataMember("keys",
&SelectKOptions::keys),
+ DataMember("order", &SelectKOptions::order),
+ DataMember("keep", &SelectKOptions::keep));
Review comment:
nit: it'd be good to add a couple instances of the options to the test
in function_test.cc:
https://github.com/apache/arrow/blob/3da09e51e3f94ca427915c98ac59d72701f728bc/cpp/src/arrow/compute/function_test.cc#L112
##########
File path: cpp/src/arrow/util/heap.h
##########
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <functional>
+#include <vector>
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+// A Heap class, is a simple wrapper to make heap operation simpler.
+// This class is immutable by design
Review comment:
The heap is immutable?
##########
File path: cpp/src/arrow/util/heap.h
##########
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <functional>
+#include <vector>
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+// A Heap class, is a simple wrapper to make heap operation simpler.
+// This class is immutable by design
+template <typename T, typename Compare = std::less<T>>
+class ARROW_EXPORT Heap {
+ public:
+ explicit Heap() : values_(), comp_() {}
Review comment:
note there are some lints for this file:
https://github.com/apache/arrow/pull/11019/checks?check_run_id=3468575773
```
/arrow/cpp/src/arrow/util/heap.h:33: Zero-parameter constructors should
not be marked explicit. [runtime/explicit] [5]
/arrow/cpp/src/arrow/util/heap.h:76: Could not find a newline character at
the end of the file. [whitespace/ending_newline] [5]
```
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+template <typename ArrowType>
+enable_if_t<TypeTraits<ArrowType>::is_parameter_free,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return TypeTraits<ArrowType>::type_singleton();
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, TimestampType>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return timestamp(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time32Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time32(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time64Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time64(TimeUnit::NANO);
+}
+
+// ----------------------------------------------------------------------
+// Tests for SelectK
+
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+} // namespace
+
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (is_floating_type<typename ArrayType::TypeClass>::value) {
+ // NaNs ordered after non-NaNs
+ if (rval != rval) return true;
Review comment:
use std::isnan?
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+template <typename ArrowType>
+enable_if_t<TypeTraits<ArrowType>::is_parameter_free,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return TypeTraits<ArrowType>::type_singleton();
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, TimestampType>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return timestamp(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time32Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time32(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time64Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time64(TimeUnit::NANO);
+}
+
+// ----------------------------------------------------------------------
+// Tests for SelectK
+
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+} // namespace
+
+// Ordering predicate for checking SelectK output: yields true when `lval` is
+// allowed to precede (or tie with) `rval` under the compile-time `order`.
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+  template <typename Type>
+  bool operator()(const Type& lval, const Type& rval) {
+    const bool is_float = is_floating_type<typename ArrayType::TypeClass>::value;
+    // A NaN compares unequal to itself; NaNs are ordered after non-NaNs
+    // whichever direction is requested.
+    if (is_float && rval != rval) return true;
+    if (is_float && lval != lval) return false;
+    return (order == SortOrder::Ascending) ? (lval <= rval) : (rval <= lval);
+  }
+};
+
+// Runs the kernel matching `order` on a chunked array: TopK for Descending,
+// BottomK for anything else.
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) {
+  if (order != SortOrder::Descending) {
+    return BottomK(values, k);
+  }
+  return TopK(values, k);
+}
+
+// Runs the kernel matching `order` on a single array: TopK for Descending,
+// BottomK for anything else.
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
+  if (order != SortOrder::Descending) {
+    return BottomK(values, k);
+  }
+  return TopK(values, k);
+}
+// Fixture shared by the typed SelectK tests.
+//
+// The helpers run TopK / BottomK on an input array and compare the result with
+// the leading elements of that same array sorted via SortIndices + Take.
+// Subclasses supply the concrete DataType through GetType().
+template <typename ArrowType>
+class TestSelectKBase : public TestBase {
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+  // Compares `select_k`, element by element, with the prefix of `array` sorted
+  // in `order`. No comparison is made when k >= array.length().
+  void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder order) {
+    ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order));
+    ASSERT_OK_AND_ASSIGN(Datum sorted_datum,
+                         Take(array, sorted_indices, TakeOptions::NoBoundsCheck()));
+    std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array();
+
+    const ArrayType& sorted_array = *checked_pointer_cast<ArrayType>(sorted_array_out);
+
+    if (k < array.length()) {
+      // Guard the indexed reads below: the reference array must be at least as
+      // long as the selection being checked.
+      ASSERT_LE(select_k.length(), sorted_array.length());
+      for (uint64_t i = 0; i < static_cast<uint64_t>(select_k.length()); ++i) {
+        const auto lval = GetLogicalValue(select_k, i);
+        const auto rval = GetLogicalValue(sorted_array, i);
+        ASSERT_TRUE(lval == rval);
+      }
+    }
+  }
+
+  // Runs SelectK<order> on `values`, asserts the result reports a null count
+  // of zero, validates the output, then checks it against the sorted
+  // reference.
+  template <SortOrder order>
+  void AssertSelectKArray(const std::shared_ptr<Array> values, int n) {
+    std::shared_ptr<Array> select_k;
+    ASSERT_OK_AND_ASSIGN(select_k, SelectK<order>(*values, n));
+    ASSERT_EQ(select_k->data()->null_count, 0);
+    ValidateOutput(*select_k);
+    Validate(*checked_pointer_cast<ArrayType>(values), n,
+             *checked_pointer_cast<ArrayType>(select_k), order);
+  }
+
+  // TopK is checked against a descending sort.
+  void AssertTopKArray(const std::shared_ptr<Array> values, int n) {
+    AssertSelectKArray<SortOrder::Descending>(values, n);
+  }
+
+  // BottomK is checked against an ascending sort. This previously used
+  // Descending, which re-ran the TopK check and left BottomK untested.
+  void AssertBottomKArray(const std::shared_ptr<Array> values, int n) {
+    AssertSelectKArray<SortOrder::Ascending>(values, n);
+  }
+
+  // Builds an array of GetType() from the JSON text and checks both TopK and
+  // BottomK with the same `n`.
+  void AssertSelectKJson(const std::string& values, int n) {
+    AssertTopKArray(ArrayFromJSON(GetType(), values), n);
+    AssertBottomKArray(ArrayFromJSON(GetType(), values), n);
+  }
+
+  virtual std::shared_ptr<DataType> GetType() = 0;
+};
+
+template <typename ArrowType>
+class TestSelectK : public TestSelectKBase<ArrowType> {
+ protected:
+ std::shared_ptr<DataType> GetType() override { return
TypeToDataType<ArrowType>(); }
+};
+
+template <typename ArrowType>
+class TestSelectKForReal : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForReal, RealArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForIntegral : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForIntegral, IntegralArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForBool : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForBool, ::testing::Types<BooleanType>);
+
+template <typename ArrowType>
+class TestSelectKForTemporal : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForTemporal, TemporalArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForDecimal : public TestSelectKBase<ArrowType> {
+ std::shared_ptr<DataType> GetType() override {
+ return std::make_shared<ArrowType>(5, 2);
+ }
+};
+TYPED_TEST_SUITE(TestSelectKForDecimal, DecimalArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForStrings : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForStrings, testing::Types<StringType>);
+
+TYPED_TEST(TestSelectKForReal, SelectKDoesNotProvideDefaultOptions) {
+ auto input = ArrayFromJSON(this->GetType(), "[null, 1, 3.3, null, 2, 5.3]");
+ ASSERT_RAISES(Invalid, CallFunction("top_k", {input}));
+}
+
+TYPED_TEST(TestSelectKForReal, Real) {
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 0);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 2);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 5);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 6);
+
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 0);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4);
+ this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3);
+ this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4);
+}
+
+TYPED_TEST(TestSelectKForIntegral, Integral) {
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6);
+
+ this->AssertSelectKJson("[2, 4, 5, 7, 8, 0, 9, 1, 3]", 5);
+}
+
+TYPED_TEST(TestSelectKForBool, Bool) {
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 0);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 2);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 5);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 6);
+}
+
+TYPED_TEST(TestSelectKForTemporal, Temporal) {
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6);
+}
+
+TYPED_TEST(TestSelectKForDecimal, Decimal) {
+ const std::string values = R"(["123.45", null, "-123.45", "456.78",
"-456.78"])";
+ this->AssertSelectKJson(values, 0);
+ this->AssertSelectKJson(values, 2);
+ this->AssertSelectKJson(values, 4);
+ this->AssertSelectKJson(values, 5);
+}
+
+TYPED_TEST(TestSelectKForStrings, Strings) {
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 0);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 2);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 5);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 6);
+}
+
+/*struct TestSelectKWithKeepParam : public ::testing::Test {
+ void Check(const std::shared_ptr<DataType>& type, const std::string& values,
int64_t k,
+ const std::string& expected) {
+ std::shared_ptr<Array> actual;
+ ASSERT_OK(this->DoSelectK(type, values, k, &actual));
+ ValidateOutput(actual);
+
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type, const
std::shared_ptr<Array>& values,
+ int64_t k, const std::string& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, TopK(*values, k));
+ ValidateOutput(actual);
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ Status DoSelectK(const std::shared_ptr<DataType>& type, const std::string&
values,
+ int64_t k, std::shared_ptr<Array>* out) {
+ ARROW_ASSIGN_OR_RAISE(*out, TopK(*ArrayFromJSON(type, values), k));
+ PrettyPrint(**out, {}, &std::cerr);
+ return Status::OK();
+ }
+};
+
+TEST_F(TestSelectKWithKeepParam, Integral) {
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 0, "[]");
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 2, "[5, 3]");
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 5, "[5, 3, 2, 1]");
+ this->Check(int32(), "[1, 2, 2, 3, 3, 3, 4, 4, 4, 4]", 5, "[4, 4, 4, 4, 3]");
+}*/
+
+template <typename ArrowType>
+class TestSelectKRandom : public TestSelectKBase<ArrowType> {
+ public:
+ std::shared_ptr<DataType> GetType() override {
+ EXPECT_TRUE(0) << "shouldn't be used";
+ return nullptr;
+ }
+};
+
+using SelectKableTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type,
Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType,
Decimal128Type,
+ StringType>;
+
+class RandomImpl {
+ protected:
+ random::RandomArrayGenerator generator_;
+ std::shared_ptr<DataType> type_;
+
+ explicit RandomImpl(random::SeedType seed, std::shared_ptr<DataType> type)
+ : generator_(seed), type_(std::move(type)) {}
+
+ public:
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob) {
+ return generator_.ArrayOf(type_, count, null_prob);
+ }
+};
+
+template <typename ArrowType>
+class Random : public RandomImpl {
+ public:
+ explicit Random(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
+};
+
+template <>
+class Random<FloatType> : public RandomImpl {
+ using CType = float;
+
+ public:
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float32()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double
nan_prob = 0) {
+ return generator_.Float32(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob,
nan_prob);
+ }
+};
+
+template <>
+class Random<DoubleType> : public RandomImpl {
+ using CType = double;
+
+ public:
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float64()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double
nan_prob = 0) {
+ return generator_.Float64(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob,
nan_prob);
+ }
+};
+
+template <>
+class Random<Decimal128Type> : public RandomImpl {
+ public:
+ explicit Random(random::SeedType seed,
+ std::shared_ptr<DataType> type = decimal128(18, 5))
+ : RandomImpl(seed, std::move(type)) {}
+};
+
+template <typename ArrowType>
+class RandomRange : public RandomImpl {
+ using CType = typename TypeTraits<ArrowType>::CType;
+
+ public:
+ explicit RandomRange(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, int range, double null_prob)
{
+ CType min = std::numeric_limits<CType>::min();
+ CType max = min + range;
+ if (sizeof(CType) < 4 && (range + min) >
std::numeric_limits<CType>::max()) {
+ max = std::numeric_limits<CType>::max();
+ }
+ return generator_.Numeric<ArrowType>(count, min, max, null_prob);
+ }
+};
+
+TYPED_TEST_SUITE(TestSelectKRandom, SelectKableTypes);
+
+TYPED_TEST(TestSelectKRandom, RandomValues) {
+ Random<TypeParam> rand(0x61549225);
+ int length = 100;
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ // Try n from 0 to out of bound
+ for (int n = 0; n <= length; ++n) {
+ auto array = rand.Generate(length, null_probability);
+ this->AssertTopKArray(array, n);
+ this->AssertBottomKArray(array, n);
+ }
+ }
+}
+
+template <SortOrder order>
+struct TestSelectKWithChunkedArray : public ::testing::Test {
+ TestSelectKWithChunkedArray()
+ : sizes_({0, 1, 2, 4, 16, 31, 1234}),
+ null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {}
+
+ void Check(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, int64_t k,
+ const std::string& expected) {
+ std::shared_ptr<Array> actual;
+ ASSERT_OK(this->DoSelectK(type, values, k, &actual));
+ ValidateOutput(actual);
+
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<ChunkedArray>& values, int64_t k,
+ const std::string& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, SelectK<order>(*values, k));
+ ValidateOutput(actual);
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ Status DoSelectK(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, int64_t k,
+ std::shared_ptr<Array>* out) {
+ ARROW_ASSIGN_OR_RAISE(*out, SelectK<order>(*(ChunkedArrayFromJSON(type,
values)), k));
+ PrettyPrint(**out, {}, &std::cerr);
+ return Status::OK();
+ }
+ std::vector<int32_t> sizes_;
+ std::vector<double> null_probabilities_;
+};
+
+struct TestTopKWithChunkedArray
Review comment:
Why not make these checks above in the typed tests?
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+template <typename ArrowType>
+enable_if_t<TypeTraits<ArrowType>::is_parameter_free,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return TypeTraits<ArrowType>::type_singleton();
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, TimestampType>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return timestamp(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time32Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time32(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time64Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time64(TimeUnit::NANO);
+}
+
+// ----------------------------------------------------------------------
+// Tests for SelectK
+
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+} // namespace
+
+// Ordering predicate for checking SelectK output: yields true when `lval` is
+// allowed to precede (or tie with) `rval` under the compile-time `order`.
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+  template <typename Type>
+  bool operator()(const Type& lval, const Type& rval) {
+    const bool is_float = is_floating_type<typename ArrayType::TypeClass>::value;
+    // A NaN compares unequal to itself; NaNs are ordered after non-NaNs
+    // whichever direction is requested.
+    if (is_float && rval != rval) return true;
+    if (is_float && lval != lval) return false;
+    return (order == SortOrder::Ascending) ? (lval <= rval) : (rval <= lval);
+  }
+};
+
+// Runs the kernel matching `order` on a chunked array: TopK for Descending,
+// BottomK for anything else.
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) {
+  if (order != SortOrder::Descending) {
+    return BottomK(values, k);
+  }
+  return TopK(values, k);
+}
+
+// Runs the kernel matching `order` on a single array: TopK for Descending,
+// BottomK for anything else.
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
+  if (order != SortOrder::Descending) {
+    return BottomK(values, k);
+  }
+  return TopK(values, k);
+}
+// Fixture shared by the typed SelectK tests.
+//
+// The helpers run TopK / BottomK on an input array and compare the result with
+// the leading elements of that same array sorted via SortIndices + Take.
+// Subclasses supply the concrete DataType through GetType().
+template <typename ArrowType>
+class TestSelectKBase : public TestBase {
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+  // Compares `select_k`, element by element, with the prefix of `array` sorted
+  // in `order`. No comparison is made when k >= array.length().
+  void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder order) {
+    ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order));
+    ASSERT_OK_AND_ASSIGN(Datum sorted_datum,
+                         Take(array, sorted_indices, TakeOptions::NoBoundsCheck()));
+    std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array();
+
+    const ArrayType& sorted_array = *checked_pointer_cast<ArrayType>(sorted_array_out);
+
+    if (k < array.length()) {
+      // Guard the indexed reads below: the reference array must be at least as
+      // long as the selection being checked.
+      ASSERT_LE(select_k.length(), sorted_array.length());
+      for (uint64_t i = 0; i < static_cast<uint64_t>(select_k.length()); ++i) {
+        const auto lval = GetLogicalValue(select_k, i);
+        const auto rval = GetLogicalValue(sorted_array, i);
+        ASSERT_TRUE(lval == rval);
+      }
+    }
+  }
+
+  // Runs SelectK<order> on `values`, asserts the result reports a null count
+  // of zero, validates the output, then checks it against the sorted
+  // reference.
+  template <SortOrder order>
+  void AssertSelectKArray(const std::shared_ptr<Array> values, int n) {
+    std::shared_ptr<Array> select_k;
+    ASSERT_OK_AND_ASSIGN(select_k, SelectK<order>(*values, n));
+    ASSERT_EQ(select_k->data()->null_count, 0);
+    ValidateOutput(*select_k);
+    Validate(*checked_pointer_cast<ArrayType>(values), n,
+             *checked_pointer_cast<ArrayType>(select_k), order);
+  }
+
+  // TopK is checked against a descending sort.
+  void AssertTopKArray(const std::shared_ptr<Array> values, int n) {
+    AssertSelectKArray<SortOrder::Descending>(values, n);
+  }
+
+  // BottomK is checked against an ascending sort. This previously used
+  // Descending, which re-ran the TopK check and left BottomK untested.
+  void AssertBottomKArray(const std::shared_ptr<Array> values, int n) {
+    AssertSelectKArray<SortOrder::Ascending>(values, n);
+  }
+
+  // Builds an array of GetType() from the JSON text and checks both TopK and
+  // BottomK with the same `n`.
+  void AssertSelectKJson(const std::string& values, int n) {
+    AssertTopKArray(ArrayFromJSON(GetType(), values), n);
+    AssertBottomKArray(ArrayFromJSON(GetType(), values), n);
+  }
+
+  virtual std::shared_ptr<DataType> GetType() = 0;
+};
+
+template <typename ArrowType>
+class TestSelectK : public TestSelectKBase<ArrowType> {
+ protected:
+ std::shared_ptr<DataType> GetType() override { return
TypeToDataType<ArrowType>(); }
+};
+
+template <typename ArrowType>
+class TestSelectKForReal : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForReal, RealArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForIntegral : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForIntegral, IntegralArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForBool : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForBool, ::testing::Types<BooleanType>);
+
+template <typename ArrowType>
+class TestSelectKForTemporal : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForTemporal, TemporalArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForDecimal : public TestSelectKBase<ArrowType> {
+ std::shared_ptr<DataType> GetType() override {
+ return std::make_shared<ArrowType>(5, 2);
+ }
+};
+TYPED_TEST_SUITE(TestSelectKForDecimal, DecimalArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForStrings : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForStrings, testing::Types<StringType>);
+
+TYPED_TEST(TestSelectKForReal, SelectKDoesNotProvideDefaultOptions) {
+ auto input = ArrayFromJSON(this->GetType(), "[null, 1, 3.3, null, 2, 5.3]");
+ ASSERT_RAISES(Invalid, CallFunction("top_k", {input}));
+}
+
+TYPED_TEST(TestSelectKForReal, Real) {
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 0);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 2);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 5);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 6);
+
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 0);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4);
+ this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3);
+ this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4);
+}
+
+TYPED_TEST(TestSelectKForIntegral, Integral) {
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6);
+
+ this->AssertSelectKJson("[2, 4, 5, 7, 8, 0, 9, 1, 3]", 5);
+}
+
+TYPED_TEST(TestSelectKForBool, Bool) {
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 0);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 2);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 5);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 6);
+}
+
+TYPED_TEST(TestSelectKForTemporal, Temporal) {
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6);
+}
+
+TYPED_TEST(TestSelectKForDecimal, Decimal) {
+ const std::string values = R"(["123.45", null, "-123.45", "456.78",
"-456.78"])";
+ this->AssertSelectKJson(values, 0);
+ this->AssertSelectKJson(values, 2);
+ this->AssertSelectKJson(values, 4);
+ this->AssertSelectKJson(values, 5);
+}
+
+TYPED_TEST(TestSelectKForStrings, Strings) {
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 0);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 2);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 5);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 6);
+}
+
+/*struct TestSelectKWithKeepParam : public ::testing::Test {
+ void Check(const std::shared_ptr<DataType>& type, const std::string& values,
int64_t k,
+ const std::string& expected) {
+ std::shared_ptr<Array> actual;
+ ASSERT_OK(this->DoSelectK(type, values, k, &actual));
+ ValidateOutput(actual);
+
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type, const
std::shared_ptr<Array>& values,
+ int64_t k, const std::string& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, TopK(*values, k));
+ ValidateOutput(actual);
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ Status DoSelectK(const std::shared_ptr<DataType>& type, const std::string&
values,
+ int64_t k, std::shared_ptr<Array>* out) {
+ ARROW_ASSIGN_OR_RAISE(*out, TopK(*ArrayFromJSON(type, values), k));
+ PrettyPrint(**out, {}, &std::cerr);
+ return Status::OK();
+ }
+};
+
+TEST_F(TestSelectKWithKeepParam, Integral) {
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 0, "[]");
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 2, "[5, 3]");
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 5, "[5, 3, 2, 1]");
+ this->Check(int32(), "[1, 2, 2, 3, 3, 3, 4, 4, 4, 4]", 5, "[4, 4, 4, 4, 3]");
+}*/
+
+template <typename ArrowType>
+class TestSelectKRandom : public TestSelectKBase<ArrowType> {
+ public:
+ std::shared_ptr<DataType> GetType() override {
+ EXPECT_TRUE(0) << "shouldn't be used";
+ return nullptr;
+ }
+};
+
+using SelectKableTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type,
Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType,
Decimal128Type,
+ StringType>;
+
+class RandomImpl {
+ protected:
+ random::RandomArrayGenerator generator_;
+ std::shared_ptr<DataType> type_;
+
+ explicit RandomImpl(random::SeedType seed, std::shared_ptr<DataType> type)
+ : generator_(seed), type_(std::move(type)) {}
+
+ public:
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob) {
+ return generator_.ArrayOf(type_, count, null_prob);
+ }
+};
+
+template <typename ArrowType>
+class Random : public RandomImpl {
+ public:
+ explicit Random(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
+};
+
+template <>
+class Random<FloatType> : public RandomImpl {
+ using CType = float;
+
+ public:
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float32()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double
nan_prob = 0) {
+ return generator_.Float32(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob,
nan_prob);
+ }
+};
+
+template <>
+class Random<DoubleType> : public RandomImpl {
+ using CType = double;
+
+ public:
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float64()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double
nan_prob = 0) {
+ return generator_.Float64(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob,
nan_prob);
+ }
+};
+
+template <>
+class Random<Decimal128Type> : public RandomImpl {
+ public:
+ explicit Random(random::SeedType seed,
+ std::shared_ptr<DataType> type = decimal128(18, 5))
+ : RandomImpl(seed, std::move(type)) {}
+};
+
+template <typename ArrowType>
+class RandomRange : public RandomImpl {
+ using CType = typename TypeTraits<ArrowType>::CType;
+
+ public:
+ explicit RandomRange(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, int range, double null_prob)
{
+ CType min = std::numeric_limits<CType>::min();
+ CType max = min + range;
+ if (sizeof(CType) < 4 && (range + min) >
std::numeric_limits<CType>::max()) {
+ max = std::numeric_limits<CType>::max();
+ }
+ return generator_.Numeric<ArrowType>(count, min, max, null_prob);
+ }
+};
+
+TYPED_TEST_SUITE(TestSelectKRandom, SelectKableTypes);
+
+TYPED_TEST(TestSelectKRandom, RandomValues) {
+ Random<TypeParam> rand(0x61549225);
+ int length = 100;
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ // Try n from 0 to out of bound
+ for (int n = 0; n <= length; ++n) {
+ auto array = rand.Generate(length, null_probability);
+ this->AssertTopKArray(array, n);
+ this->AssertBottomKArray(array, n);
+ }
+ }
+}
+
+template <SortOrder order>
+struct TestSelectKWithChunkedArray : public ::testing::Test {
+ TestSelectKWithChunkedArray()
+ : sizes_({0, 1, 2, 4, 16, 31, 1234}),
+ null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {}
+
+ void Check(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, int64_t k,
+ const std::string& expected) {
+ std::shared_ptr<Array> actual;
+ ASSERT_OK(this->DoSelectK(type, values, k, &actual));
+ ValidateOutput(actual);
+
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<ChunkedArray>& values, int64_t k,
+ const std::string& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, SelectK<order>(*values, k));
+ ValidateOutput(actual);
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ Status DoSelectK(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, int64_t k,
+ std::shared_ptr<Array>* out) {
+ ARROW_ASSIGN_OR_RAISE(*out, SelectK<order>(*(ChunkedArrayFromJSON(type,
values)), k));
+ PrettyPrint(**out, {}, &std::cerr);
+ return Status::OK();
+ }
+ std::vector<int32_t> sizes_;
+ std::vector<double> null_probabilities_;
+};
+
+struct TestTopKWithChunkedArray
Review comment:
Also, you can probably make one typed test suite that covers all the
numeric and most of the temporal types in one go, since the temporal types
generally are numeric underneath.
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -252,6 +271,40 @@ ARROW_EXPORT
Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n,
ExecContext* ctx = NULLPTR);
+/// @TODO
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> TopK(const Array& values, int64_t k,
+ const std::string& keep = "first",
+ ExecContext* ctx = NULLPTR);
+
+/// @TODO
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> TopK(const ChunkedArray& values, int64_t k,
+ const std::string& keep = "first",
+ ExecContext* ctx = NULLPTR);
+
+/// @TODO
+ARROW_EXPORT
+Result<Datum> TopK(const Datum& datum, int64_t k, SelectKOptions options,
+ ExecContext* ctx = NULLPTR);
+
+/// @TODO
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> BottomK(const Array& values, int64_t k,
+ const std::string& keep = "first",
+ ExecContext* ctx = NULLPTR);
+
+/// @TODO
+ARROW_EXPORT
+Result<std::shared_ptr<Array>> BottomK(const ChunkedArray& values, int64_t k,
+ const std::string& keep = "first",
+ ExecContext* ctx = NULLPTR);
+
+/// @TODO
+ARROW_EXPORT
+Result<Datum> BottomK(const Datum& datum, int64_t k, SelectKOptions options,
Review comment:
I'm not sure we need all these overloads in C++. They may instead be
better suited as static factory methods on the options (SelectKOptions::BottomK
and such).
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -120,6 +120,25 @@ class ARROW_EXPORT SortOptions : public FunctionOptions {
std::vector<SortKey> sort_keys;
};
+/// \brief Partitioning options for NthToIndices
+class ARROW_EXPORT SelectKOptions : public FunctionOptions {
+ public:
+ explicit SelectKOptions(int64_t pivot = 0, std::vector<std::string> keys =
{},
+ std::string keep = "first",
+ SortOrder order = SortOrder::Ascending);
+ constexpr static char const kTypeName[] = "SelectKOptions";
+ static SelectKOptions TopKDefault() {
+ return SelectKOptions{0, {}, "first", SortOrder::Descending};
+ }
+ static SelectKOptions BottomKDefault() {
+ return SelectKOptions{0, {}, "first", SortOrder::Ascending};
+ }
+ int64_t k;
+ std::vector<std::string> keys;
+ std::string keep;
+ SortOrder order;
Review comment:
nit: we should have docstrings for these fields.
##########
File path: cpp/src/arrow/testing/gtest_util.h
##########
@@ -157,8 +157,7 @@ using NumericArrowTypes =
using RealArrowTypes = ::testing::Types<FloatType, DoubleType>;
-using IntegralArrowTypes = ::testing::Types<UInt8Type, UInt16Type, UInt32Type,
UInt64Type,
- Int8Type, Int16Type, Int32Type,
Int64Type>;
Review comment:
It seems this is WIP but please don't forget to revert this.
##########
File path: cpp/src/arrow/util/heap.h
##########
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
Review comment:
Can we rename this file to heap_internal.h to make it clear it's not part of
our public API?
##########
File path: cpp/src/arrow/compute/api_vector.h
##########
@@ -120,6 +120,25 @@ class ARROW_EXPORT SortOptions : public FunctionOptions {
std::vector<SortKey> sort_keys;
};
+/// \brief Partitioning options for NthToIndices
Review comment:
nit: this comment should be updated; it still says "NthToIndices" even though
the class is SelectKOptions.
##########
File path: cpp/src/arrow/util/heap.h
##########
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <functional>
+#include <vector>
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+// A Heap class, is a simple wrapper to make heap operation simpler.
+// This class is immutable by design
+template <typename T, typename Compare = std::less<T>>
+class ARROW_EXPORT Heap {
Review comment:
We should have direct tests of the heap itself.
##########
File path: cpp/src/arrow/util/heap.h
##########
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <functional>
+#include <vector>
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+// A Heap class, is a simple wrapper to make heap operation simpler.
+// This class is immutable by design
+template <typename T, typename Compare = std::less<T>>
+class ARROW_EXPORT Heap {
+ public:
+ explicit Heap() : values_(), comp_() {}
+ explicit Heap(const Compare& compare) : values_(), comp_(compare) {}
+
+ Heap(Heap&&) = default;
+ Heap& operator=(Heap&&) = default;
+
+ T* Data() { return values_.data(); }
+
+ // const T& Top() const { return values_.front(); }
+
+ T Top() const { return values_.front(); }
+
+ bool Empty() const { return values_.empty(); }
+
+ size_t Size() const { return values_.size(); }
+
+ void Push(const T& value) {
+ values_.push_back(value);
+ std::push_heap(values_.begin(), values_.end(), comp_);
+ }
+
+ void Pop() {
+ std::pop_heap(values_.begin(), values_.end(), comp_);
+ values_.pop_back();
+ }
+
+ void ReplaceTop(const T& value) {
+ std::pop_heap(values_.begin(), values_.end(), comp_);
+ values_.back() = value;
+ std::push_heap(values_.begin(), values_.end(), comp_);
+ }
+
+ void SetComparator(const Compare& comp) { comp_ = comp; }
Review comment:
Is SetComparator used anywhere?
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+template <typename ArrowType>
+enable_if_t<TypeTraits<ArrowType>::is_parameter_free,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return TypeTraits<ArrowType>::type_singleton();
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, TimestampType>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return timestamp(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time32Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time32(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time64Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time64(TimeUnit::NANO);
+}
+
+// ----------------------------------------------------------------------
+// Tests for SelectK
+
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+} // namespace
+
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (is_floating_type<typename ArrayType::TypeClass>::value) {
+ // NaNs ordered after non-NaNs
+ if (rval != rval) return true;
+ if (lval != lval) return false;
+ }
+ if (order == SortOrder::Ascending) {
+ return lval <= rval;
+ } else {
+ return rval <= lval;
+ }
+ }
+};
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+template <typename ArrowType>
+class TestSelectKBase : public TestBase {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+ void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder
order) {
+ ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order));
+ ASSERT_OK_AND_ASSIGN(Datum sorted_datum,
+ Take(array, sorted_indices,
TakeOptions::NoBoundsCheck()));
+ std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array();
+
+ const ArrayType& sorted_array =
*checked_pointer_cast<ArrayType>(sorted_array_out);
+
+ if (k < array.length()) {
+ for (uint64_t i = 0; i < (uint64_t)select_k.length(); ++i) {
+ const auto lval = GetLogicalValue(select_k, i);
+ const auto rval = GetLogicalValue(sorted_array, i);
Review comment:
Instead of GetLogicalValue and such, you could slice the arrays and use
AssertArraysEqual/AssertArraysApproxEqual. If you pass `verbose = true`,
this will also print out the mismatched values on failure.
##########
File path: cpp/src/arrow/util/heap.h
##########
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <functional>
+#include <vector>
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+// A Heap class, is a simple wrapper to make heap operation simpler.
+// This class is immutable by design
+template <typename T, typename Compare = std::less<T>>
+class ARROW_EXPORT Heap {
+ public:
+ explicit Heap() : values_(), comp_() {}
+ explicit Heap(const Compare& compare) : values_(), comp_(compare) {}
+
+ Heap(Heap&&) = default;
+ Heap& operator=(Heap&&) = default;
Review comment:
You might enjoy ARROW_DEFAULT_MOVE_AND_ASSIGN.
https://github.com/apache/arrow/blob/3da09e51e3f94ca427915c98ac59d72701f728bc/cpp/src/arrow/util/macros.h#L34-L36
##########
File path: cpp/src/arrow/util/heap.h
##########
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <functional>
+#include <vector>
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+// A Heap class, is a simple wrapper to make heap operation simpler.
+// This class is immutable by design
+template <typename T, typename Compare = std::less<T>>
+class ARROW_EXPORT Heap {
+ public:
+ explicit Heap() : values_(), comp_() {}
+ explicit Heap(const Compare& compare) : values_(), comp_(compare) {}
+
+ Heap(Heap&&) = default;
+ Heap& operator=(Heap&&) = default;
+
+ T* Data() { return values_.data(); }
+
+ // const T& Top() const { return values_.front(); }
+
+ T Top() const { return values_.front(); }
+
+ bool Empty() const { return values_.empty(); }
+
+ size_t Size() const { return values_.size(); }
Review comment:
nit: const getter methods like this can still be named in lowercase. In our
codebase size() is often left as size() rather than Size(), and
top()/empty()/data() could follow the same convention.
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+template <typename ArrowType>
+enable_if_t<TypeTraits<ArrowType>::is_parameter_free,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return TypeTraits<ArrowType>::type_singleton();
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, TimestampType>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return timestamp(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time32Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time32(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time64Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time64(TimeUnit::NANO);
+}
Review comment:
You might want to use default_type_instance:
https://github.com/apache/arrow/blob/3da09e51e3f94ca427915c98ac59d72701f728bc/cpp/src/arrow/compute/kernels/test_util.h#L123-L144
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+template <typename ArrowType>
+enable_if_t<TypeTraits<ArrowType>::is_parameter_free,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return TypeTraits<ArrowType>::type_singleton();
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, TimestampType>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return timestamp(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time32Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time32(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time64Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time64(TimeUnit::NANO);
+}
+
+// ----------------------------------------------------------------------
+// Tests for SelectK
+
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+} // namespace
+
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (is_floating_type<typename ArrayType::TypeClass>::value) {
+ // NaNs ordered after non-NaNs
+ if (rval != rval) return true;
+ if (lval != lval) return false;
+ }
+ if (order == SortOrder::Ascending) {
+ return lval <= rval;
+ } else {
+ return rval <= lval;
+ }
+ }
+};
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+template <typename ArrowType>
+class TestSelectKBase : public TestBase {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+ void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder
order) {
+ ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order));
+ ASSERT_OK_AND_ASSIGN(Datum sorted_datum,
+ Take(array, sorted_indices,
TakeOptions::NoBoundsCheck()));
+ std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array();
+
+ const ArrayType& sorted_array =
*checked_pointer_cast<ArrayType>(sorted_array_out);
+
+ if (k < array.length()) {
+ for (uint64_t i = 0; i < (uint64_t)select_k.length(); ++i) {
+ const auto lval = GetLogicalValue(select_k, i);
+ const auto rval = GetLogicalValue(sorted_array, i);
+ ASSERT_TRUE(lval == rval);
+ }
+ }
+ }
+ template <SortOrder order>
+ void AssertSelectKArray(const std::shared_ptr<Array> values, int n) {
+ std::shared_ptr<Array> select_k;
+ ASSERT_OK_AND_ASSIGN(select_k, SelectK<order>(*values, n));
+ ASSERT_EQ(select_k->data()->null_count, 0);
+ ValidateOutput(*select_k);
+ Validate(*checked_pointer_cast<ArrayType>(values), n,
+ *checked_pointer_cast<ArrayType>(select_k), order);
+ }
+
+ void AssertTopKArray(const std::shared_ptr<Array> values, int n) {
+ AssertSelectKArray<SortOrder::Descending>(values, n);
+ }
+ void AssertBottomKArray(const std::shared_ptr<Array> values, int n) {
+ AssertSelectKArray<SortOrder::Descending>(values, n);
+ }
+
+ void AssertSelectKJson(const std::string& values, int n) {
+ AssertTopKArray(ArrayFromJSON(GetType(), values), n);
+ AssertBottomKArray(ArrayFromJSON(GetType(), values), n);
+ }
+
+ virtual std::shared_ptr<DataType> GetType() = 0;
Review comment:
nit: most other such tests in our codebase call this type_singleton
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+template <typename ArrowType>
+enable_if_t<TypeTraits<ArrowType>::is_parameter_free,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return TypeTraits<ArrowType>::type_singleton();
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, TimestampType>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return timestamp(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time32Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time32(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time64Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time64(TimeUnit::NANO);
+}
+
+// ----------------------------------------------------------------------
+// Tests for SelectK
+
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+} // namespace
+
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (is_floating_type<typename ArrayType::TypeClass>::value) {
+ // NaNs ordered after non-NaNs
+ if (rval != rval) return true;
+ if (lval != lval) return false;
+ }
+ if (order == SortOrder::Ascending) {
+ return lval <= rval;
+ } else {
+ return rval <= lval;
+ }
+ }
+};
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+template <typename ArrowType>
+class TestSelectKBase : public TestBase {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+ void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder
order) {
+ ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order));
+ ASSERT_OK_AND_ASSIGN(Datum sorted_datum,
+ Take(array, sorted_indices,
TakeOptions::NoBoundsCheck()));
+ std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array();
+
+ const ArrayType& sorted_array =
*checked_pointer_cast<ArrayType>(sorted_array_out);
+
+ if (k < array.length()) {
+ for (uint64_t i = 0; i < (uint64_t)select_k.length(); ++i) {
+ const auto lval = GetLogicalValue(select_k, i);
+ const auto rval = GetLogicalValue(sorted_array, i);
+ ASSERT_TRUE(lval == rval);
+ }
+ }
+ }
+ template <SortOrder order>
+ void AssertSelectKArray(const std::shared_ptr<Array> values, int n) {
+ std::shared_ptr<Array> select_k;
+ ASSERT_OK_AND_ASSIGN(select_k, SelectK<order>(*values, n));
Review comment:
Why is everything templated instead of passing it as a parameter?
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+template <typename ArrowType>
+enable_if_t<TypeTraits<ArrowType>::is_parameter_free,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return TypeTraits<ArrowType>::type_singleton();
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, TimestampType>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return timestamp(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time32Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time32(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time64Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time64(TimeUnit::NANO);
+}
+
+// ----------------------------------------------------------------------
+// Tests for SelectK
+
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+} // namespace
+
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (is_floating_type<typename ArrayType::TypeClass>::value) {
+ // NaNs ordered after non-NaNs
+ if (rval != rval) return true;
+ if (lval != lval) return false;
+ }
+ if (order == SortOrder::Ascending) {
+ return lval <= rval;
+ } else {
+ return rval <= lval;
+ }
+ }
+};
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+template <typename ArrowType>
+class TestSelectKBase : public TestBase {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+ void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder
order) {
+ ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order));
+ ASSERT_OK_AND_ASSIGN(Datum sorted_datum,
+ Take(array, sorted_indices,
TakeOptions::NoBoundsCheck()));
+ std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array();
+
+ const ArrayType& sorted_array =
*checked_pointer_cast<ArrayType>(sorted_array_out);
+
+ if (k < array.length()) {
+ for (uint64_t i = 0; i < (uint64_t)select_k.length(); ++i) {
+ const auto lval = GetLogicalValue(select_k, i);
+ const auto rval = GetLogicalValue(sorted_array, i);
Review comment:
That would also let you get rid of various casts and such.
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -156,7 +161,7 @@ struct ChunkedArrayResolver {
}
int64_t num_chunks_;
- const Array* const* chunks_;
+ const std::vector<const Array*> chunks_;
Review comment:
Was there an issue with the original (presumably chunks_ getting
invalidated)?
##########
File path: cpp/src/arrow/util/heap.h
##########
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
Review comment:
That way it won't get installed with the rest of the headers, so it'll
be usable only within Arrow itself.
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1784,711 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+using SelectKOptionsState = internal::OptionsWrapper<SelectKOptions>;
+const auto kDefaultTopKOptions = SelectKOptions::TopKDefault();
+const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault();
+
+const FunctionDoc top_k_doc(
+ "Return the indices that would partition an array array, record batch or
table\n"
+ "around a pivot",
+ ("@TODO"), {"input", "k"}, "PartitionNthOptions");
Review comment:
I realize these are TODOs, but don't forget to update them :)
##########
File path: cpp/src/arrow/util/heap.h
##########
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <functional>
+#include <vector>
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace internal {
+
+// A Heap class, is a simple wrapper to make heap operation simpler.
+// This class is immutable by design
+template <typename T, typename Compare = std::less<T>>
+class ARROW_EXPORT Heap {
+ public:
+ explicit Heap() : values_(), comp_() {}
Review comment:
You can use Docker to run linting locally:
https://arrow.apache.org/docs/developers/cpp/development.html#code-style-linting-and-ci
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+template <typename ArrowType>
+enable_if_t<TypeTraits<ArrowType>::is_parameter_free,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return TypeTraits<ArrowType>::type_singleton();
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, TimestampType>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return timestamp(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time32Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time32(TimeUnit::MILLI);
+}
+
+template <typename ArrowType>
+enable_if_t<std::is_same<ArrowType, Time64Type>::value,
std::shared_ptr<DataType>>
+TypeToDataType() {
+ return time64(TimeUnit::NANO);
+}
+
+// ----------------------------------------------------------------------
+// Tests for SelectK
+
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+} // namespace
+
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+ template <typename Type>
+ bool operator()(const Type& lval, const Type& rval) {
+ if (is_floating_type<typename ArrayType::TypeClass>::value) {
+ // NaNs ordered after non-NaNs
+ if (rval != rval) return true;
+ if (lval != lval) return false;
+ }
+ if (order == SortOrder::Ascending) {
+ return lval <= rval;
+ } else {
+ return rval <= lval;
+ }
+ }
+};
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
+ if (order == SortOrder::Descending) {
+ return TopK(values, k);
+ } else {
+ return BottomK(values, k);
+ }
+}
+template <typename ArrowType>
+class TestSelectKBase : public TestBase {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+ void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder
order) {
+ ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order));
+ ASSERT_OK_AND_ASSIGN(Datum sorted_datum,
+ Take(array, sorted_indices,
TakeOptions::NoBoundsCheck()));
+ std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array();
+
+ const ArrayType& sorted_array =
*checked_pointer_cast<ArrayType>(sorted_array_out);
+
+ if (k < array.length()) {
+ for (uint64_t i = 0; i < (uint64_t)select_k.length(); ++i) {
+ const auto lval = GetLogicalValue(select_k, i);
+ const auto rval = GetLogicalValue(sorted_array, i);
+ ASSERT_TRUE(lval == rval);
+ }
+ }
+ }
+ template <SortOrder order>
+ void AssertSelectKArray(const std::shared_ptr<Array> values, int n) {
+ std::shared_ptr<Array> select_k;
+ ASSERT_OK_AND_ASSIGN(select_k, SelectK<order>(*values, n));
+ ASSERT_EQ(select_k->data()->null_count, 0);
+ ValidateOutput(*select_k);
+ Validate(*checked_pointer_cast<ArrayType>(values), n,
+ *checked_pointer_cast<ArrayType>(select_k), order);
+ }
+
+ void AssertTopKArray(const std::shared_ptr<Array> values, int n) {
+ AssertSelectKArray<SortOrder::Descending>(values, n);
+ }
+ void AssertBottomKArray(const std::shared_ptr<Array> values, int n) {
+ AssertSelectKArray<SortOrder::Descending>(values, n);
+ }
+
+ void AssertSelectKJson(const std::string& values, int n) {
+ AssertTopKArray(ArrayFromJSON(GetType(), values), n);
+ AssertBottomKArray(ArrayFromJSON(GetType(), values), n);
+ }
+
+ virtual std::shared_ptr<DataType> GetType() = 0;
+};
+
+template <typename ArrowType>
+class TestSelectK : public TestSelectKBase<ArrowType> {
+ protected:
+ std::shared_ptr<DataType> GetType() override { return
TypeToDataType<ArrowType>(); }
+};
+
+template <typename ArrowType>
+class TestSelectKForReal : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForReal, RealArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForIntegral : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForIntegral, IntegralArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForBool : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForBool, ::testing::Types<BooleanType>);
+
+template <typename ArrowType>
+class TestSelectKForTemporal : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForTemporal, TemporalArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForDecimal : public TestSelectKBase<ArrowType> {
+ std::shared_ptr<DataType> GetType() override {
+ return std::make_shared<ArrowType>(5, 2);
+ }
+};
+TYPED_TEST_SUITE(TestSelectKForDecimal, DecimalArrowTypes);
+
+template <typename ArrowType>
+class TestSelectKForStrings : public TestSelectK<ArrowType> {};
+TYPED_TEST_SUITE(TestSelectKForStrings, testing::Types<StringType>);
+
+TYPED_TEST(TestSelectKForReal, SelectKDoesNotProvideDefaultOptions) {
+ auto input = ArrayFromJSON(this->GetType(), "[null, 1, 3.3, null, 2, 5.3]");
+ ASSERT_RAISES(Invalid, CallFunction("top_k", {input}));
+}
+
+TYPED_TEST(TestSelectKForReal, Real) {
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 0);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 2);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 5);
+ this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 6);
+
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 0);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3);
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4);
+ this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3);
+ this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4);
+}
+
+TYPED_TEST(TestSelectKForIntegral, Integral) {
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6);
+
+ this->AssertSelectKJson("[2, 4, 5, 7, 8, 0, 9, 1, 3]", 5);
+}
+
+TYPED_TEST(TestSelectKForBool, Bool) {
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 0);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 2);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 5);
+ this->AssertSelectKJson("[null, false, true, null, false, true]", 6);
+}
+
+TYPED_TEST(TestSelectKForTemporal, Temporal) {
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5);
+ this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6);
+}
+
+TYPED_TEST(TestSelectKForDecimal, Decimal) {
+ const std::string values = R"(["123.45", null, "-123.45", "456.78",
"-456.78"])";
+ this->AssertSelectKJson(values, 0);
+ this->AssertSelectKJson(values, 2);
+ this->AssertSelectKJson(values, 4);
+ this->AssertSelectKJson(values, 5);
+}
+
+TYPED_TEST(TestSelectKForStrings, Strings) {
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 0);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 2);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 5);
+ this->AssertSelectKJson(R"(["testing", null, "nth", "for", null,
"strings"])", 6);
+}
+
+/*struct TestSelectKWithKeepParam : public ::testing::Test {
+ void Check(const std::shared_ptr<DataType>& type, const std::string& values,
int64_t k,
+ const std::string& expected) {
+ std::shared_ptr<Array> actual;
+ ASSERT_OK(this->DoSelectK(type, values, k, &actual));
+ ValidateOutput(actual);
+
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type, const
std::shared_ptr<Array>& values,
+ int64_t k, const std::string& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, TopK(*values, k));
+ ValidateOutput(actual);
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ Status DoSelectK(const std::shared_ptr<DataType>& type, const std::string&
values,
+ int64_t k, std::shared_ptr<Array>* out) {
+ ARROW_ASSIGN_OR_RAISE(*out, TopK(*ArrayFromJSON(type, values), k));
+ PrettyPrint(**out, {}, &std::cerr);
+ return Status::OK();
+ }
+};
+
+TEST_F(TestSelectKWithKeepParam, Integral) {
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 0, "[]");
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 2, "[5, 3]");
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 5, "[5, 3, 2, 1]");
+ this->Check(int32(), "[1, 2, 2, 3, 3, 3, 4, 4, 4, 4]", 5, "[4, 4, 4, 4, 3]");
+}*/
+
+template <typename ArrowType>
+class TestSelectKRandom : public TestSelectKBase<ArrowType> {
+ public:
+ std::shared_ptr<DataType> GetType() override {
+ EXPECT_TRUE(0) << "shouldn't be used";
+ return nullptr;
+ }
+};
+
+using SelectKableTypes =
+ ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type,
Int16Type,
+ Int32Type, Int64Type, FloatType, DoubleType,
Decimal128Type,
+ StringType>;
+
+class RandomImpl {
+ protected:
+ random::RandomArrayGenerator generator_;
+ std::shared_ptr<DataType> type_;
+
+ explicit RandomImpl(random::SeedType seed, std::shared_ptr<DataType> type)
+ : generator_(seed), type_(std::move(type)) {}
+
+ public:
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob) {
+ return generator_.ArrayOf(type_, count, null_prob);
+ }
+};
+
+template <typename ArrowType>
+class Random : public RandomImpl {
+ public:
+ explicit Random(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
+};
+
+template <>
+class Random<FloatType> : public RandomImpl {
+ using CType = float;
+
+ public:
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float32()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double
nan_prob = 0) {
+ return generator_.Float32(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob,
nan_prob);
+ }
+};
+
+template <>
+class Random<DoubleType> : public RandomImpl {
+ using CType = double;
+
+ public:
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float64()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double
nan_prob = 0) {
+ return generator_.Float64(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob,
nan_prob);
+ }
+};
+
+template <>
+class Random<Decimal128Type> : public RandomImpl {
+ public:
+ explicit Random(random::SeedType seed,
+ std::shared_ptr<DataType> type = decimal128(18, 5))
+ : RandomImpl(seed, std::move(type)) {}
+};
+
+template <typename ArrowType>
+class RandomRange : public RandomImpl {
+ using CType = typename TypeTraits<ArrowType>::CType;
+
+ public:
+ explicit RandomRange(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
+
+ std::shared_ptr<Array> Generate(uint64_t count, int range, double null_prob)
{
+ CType min = std::numeric_limits<CType>::min();
+ CType max = min + range;
+ if (sizeof(CType) < 4 && (range + min) >
std::numeric_limits<CType>::max()) {
+ max = std::numeric_limits<CType>::max();
+ }
+ return generator_.Numeric<ArrowType>(count, min, max, null_prob);
+ }
+};
+
+TYPED_TEST_SUITE(TestSelectKRandom, SelectKableTypes);
+
+TYPED_TEST(TestSelectKRandom, RandomValues) {
+ Random<TypeParam> rand(0x61549225);
+ int length = 100;
+ for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
+ // Try n from 0 to out of bound
+ for (int n = 0; n <= length; ++n) {
+ auto array = rand.Generate(length, null_probability);
+ this->AssertTopKArray(array, n);
+ this->AssertBottomKArray(array, n);
+ }
+ }
+}
+
+template <SortOrder order>
+struct TestSelectKWithChunkedArray : public ::testing::Test {
+ TestSelectKWithChunkedArray()
+ : sizes_({0, 1, 2, 4, 16, 31, 1234}),
+ null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {}
+
+ void Check(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, int64_t k,
+ const std::string& expected) {
+ std::shared_ptr<Array> actual;
+ ASSERT_OK(this->DoSelectK(type, values, k, &actual));
+ ValidateOutput(actual);
+
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<ChunkedArray>& values, int64_t k,
+ const std::string& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, SelectK<order>(*values, k));
+ ValidateOutput(actual);
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ Status DoSelectK(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, int64_t k,
+ std::shared_ptr<Array>* out) {
+ ARROW_ASSIGN_OR_RAISE(*out, SelectK<order>(*(ChunkedArrayFromJSON(type,
values)), k));
+ PrettyPrint(**out, {}, &std::cerr);
Review comment:
Maybe move the ValidateOutput call here so you can't forget it?
##########
File path: cpp/src/arrow/compute/kernels/select_k_test.cc
##########
@@ -0,0 +1,714 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array/array_decimal.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+#include "arrow/type_traits.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+
+namespace {
+
+// Convert arrow::Type to arrow::DataType. If arrow::Type isn't
+// parameter free, this returns an arrow::DataType with the default
+// parameter.
+// Overload for parameter-free types: hand back the type singleton.
+template <typename T>
+enable_if_t<TypeTraits<T>::is_parameter_free, std::shared_ptr<DataType>>
+TypeToDataType() {
+  return TypeTraits<T>::type_singleton();
+}
+
+// Overload for TimestampType: millisecond resolution is used as the default.
+template <typename T>
+enable_if_t<std::is_same<T, TimestampType>::value, std::shared_ptr<DataType>>
+TypeToDataType() {
+  return timestamp(TimeUnit::MILLI);
+}
+
+// Overload for Time32Type: millisecond resolution is used as the default.
+template <typename T>
+enable_if_t<std::is_same<T, Time32Type>::value, std::shared_ptr<DataType>>
+TypeToDataType() {
+  return time32(TimeUnit::MILLI);
+}
+
+// Overload for Time64Type: nanosecond resolution is used as the default.
+template <typename T>
+enable_if_t<std::is_same<T, Time64Type>::value, std::shared_ptr<DataType>>
+TypeToDataType() {
+  return time64(TimeUnit::NANO);
+}
+
+// ----------------------------------------------------------------------
+// Tests for SelectK
+
+// Generic accessor: returns whatever `GetView` yields for the array type.
+template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+    -> decltype(array.GetView(index)) {
+  return array.GetView(index);
+}
+
+// Decimal128 arrays: build a comparable Decimal128 from the stored value.
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+  const auto* raw = array.Value(index);
+  return Decimal128(raw);
+}
+
+// Decimal256 arrays: build a comparable Decimal256 from the stored value.
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+  const auto* raw = array.Value(index);
+  return Decimal256(raw);
+}
+
+} // namespace
+
+// Validation predicate: returns true when `lval` may appear at or before
+// `rval` in an output sorted according to `order`. Equal values compare true
+// in both directions (the comparison is `<=`, not `<`).
+template <typename ArrayType, SortOrder order>
+class SelectKComparator {
+ public:
+  template <typename Type>
+  bool operator()(const Type& lval, const Type& rval) {
+    if (is_floating_type<typename ArrayType::TypeClass>::value) {
+      // NaNs are ordered after non-NaNs; `x != x` holds only for NaN.
+      if (rval != rval) return true;
+      if (lval != lval) return false;
+    }
+    return (order == SortOrder::Ascending) ? (lval <= rval) : (rval <= lval);
+  }
+};
+
+// Dispatches on the compile-time sort order for chunked input:
+// Descending -> TopK, anything else -> BottomK.
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) {
+  if (order != SortOrder::Descending) {
+    return BottomK(values, k);
+  }
+  return TopK(values, k);
+}
+
+// Same dispatch as above for a single contiguous Array.
+template <SortOrder order>
+Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) {
+  if (order != SortOrder::Descending) {
+    return BottomK(values, k);
+  }
+  return TopK(values, k);
+}
+
+// Shared fixture logic for the SelectK tests. Subclasses supply the concrete
+// DataType through GetType(); the helpers here run TopK/BottomK and compare
+// the result with a reference obtained by fully sorting the input.
+template <typename ArrowType>
+class TestSelectKBase : public TestBase {
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+ protected:
+  // Checks `select_k` element by element against the prefix of `array`
+  // sorted in `order`. The comparison only runs when k < array.length().
+  void Validate(const ArrayType& array, int k, const ArrayType& select_k,
+                SortOrder order) {
+    ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order));
+    ASSERT_OK_AND_ASSIGN(Datum sorted_datum,
+                         Take(array, sorted_indices, TakeOptions::NoBoundsCheck()));
+    std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array();
+
+    const ArrayType& sorted_array = *checked_pointer_cast<ArrayType>(sorted_array_out);
+
+    if (k < array.length()) {
+      const auto num_selected = static_cast<uint64_t>(select_k.length());
+      for (uint64_t i = 0; i < num_selected; ++i) {
+        const auto lval = GetLogicalValue(select_k, i);
+        const auto rval = GetLogicalValue(sorted_array, i);
+        ASSERT_TRUE(lval == rval);
+      }
+    }
+  }
+
+  // Runs SelectK<order> on `values`, requires a null-free valid output and
+  // validates it against the sorted reference.
+  template <SortOrder order>
+  void AssertSelectKArray(const std::shared_ptr<Array> values, int n) {
+    std::shared_ptr<Array> select_k;
+    ASSERT_OK_AND_ASSIGN(select_k, SelectK<order>(*values, n));
+    ASSERT_EQ(select_k->data()->null_count, 0);
+    ValidateOutput(*select_k);
+    Validate(*checked_pointer_cast<ArrayType>(values), n,
+             *checked_pointer_cast<ArrayType>(select_k), order);
+  }
+
+  // Top-k: validated against a descending sort.
+  void AssertTopKArray(const std::shared_ptr<Array> values, int n) {
+    AssertSelectKArray<SortOrder::Descending>(values, n);
+  }
+  // Bottom-k: validated against an ascending sort. (This previously used
+  // Descending, so BottomK was never actually exercised.)
+  void AssertBottomKArray(const std::shared_ptr<Array> values, int n) {
+    AssertSelectKArray<SortOrder::Ascending>(values, n);
+  }
+
+  // Parses `values` as JSON of GetType() and checks both top-k and bottom-k.
+  void AssertSelectKJson(const std::string& values, int n) {
+    AssertTopKArray(ArrayFromJSON(GetType(), values), n);
+    AssertBottomKArray(ArrayFromJSON(GetType(), values), n);
+  }
+
+  // Concrete data type used when building arrays from JSON.
+  virtual std::shared_ptr<DataType> GetType() = 0;
+};
+
+// Fixture for types whose DataType can be derived via TypeToDataType<T>().
+template <typename T>
+class TestSelectK : public TestSelectKBase<T> {
+ protected:
+  std::shared_ptr<DataType> GetType() override {
+    return TypeToDataType<T>();
+  }
+};
+
+// Typed fixture instantiated over RealArrowTypes.
+template <typename T>
+class TestSelectKForReal : public TestSelectK<T> {};
+TYPED_TEST_SUITE(TestSelectKForReal, RealArrowTypes);
+
+// Typed fixture instantiated over IntegralArrowTypes.
+template <typename T>
+class TestSelectKForIntegral : public TestSelectK<T> {};
+TYPED_TEST_SUITE(TestSelectKForIntegral, IntegralArrowTypes);
+
+// Typed fixture instantiated over BooleanType only.
+template <typename T>
+class TestSelectKForBool : public TestSelectK<T> {};
+TYPED_TEST_SUITE(TestSelectKForBool, ::testing::Types<BooleanType>);
+
+// Typed fixture instantiated over TemporalArrowTypes.
+template <typename T>
+class TestSelectKForTemporal : public TestSelectK<T> {};
+TYPED_TEST_SUITE(TestSelectKForTemporal, TemporalArrowTypes);
+
+// Typed fixture for decimal types. Decimal types carry parameters, so the
+// type is built explicitly with the arguments (5, 2) rather than through
+// TypeToDataType.
+template <typename ArrowType>
+class TestSelectKForDecimal : public TestSelectKBase<ArrowType> {
+ protected:
+  // Protected, matching TestSelectK, so typed-test bodies can call
+  // this->GetType(); the override was implicitly private before.
+  std::shared_ptr<DataType> GetType() override {
+    return std::make_shared<ArrowType>(5, 2);
+  }
+};
+TYPED_TEST_SUITE(TestSelectKForDecimal, DecimalArrowTypes);
+
+// Typed fixture instantiated over StringType only.
+template <typename T>
+class TestSelectKForStrings : public TestSelectK<T> {};
+TYPED_TEST_SUITE(TestSelectKForStrings, testing::Types<StringType>);
+
+// Calling "top_k" through the registry without options must fail with Invalid.
+TYPED_TEST(TestSelectKForReal, SelectKDoesNotProvideDefaultOptions) {
+  const auto values = ArrayFromJSON(this->GetType(), "[null, 1, 3.3, null, 2, 5.3]");
+  ASSERT_RAISES(Invalid, CallFunction("top_k", {values}));
+}
+
+// Floating-point inputs, including nulls and NaN, for a range of k values.
+TYPED_TEST(TestSelectKForReal, Real) {
+  const std::string with_nulls = "[null, 1, 3.3, null, 2, 5.3]";
+  for (int k : {0, 2, 5, 6}) {
+    this->AssertSelectKJson(with_nulls, k);
+  }
+
+  const std::string null_then_nan = "[null, 2, NaN, 3, 1]";
+  for (int k : {0, 1, 2, 3, 4}) {
+    this->AssertSelectKJson(null_then_nan, k);
+  }
+
+  const std::string nan_then_null = "[NaN, 2, null, 3, 1]";
+  for (int k : {3, 4}) {
+    this->AssertSelectKJson(nan_then_null, k);
+  }
+}
+
+// Integer inputs: one array with nulls at several k, one null-free array.
+TYPED_TEST(TestSelectKForIntegral, Integral) {
+  const std::string with_nulls = "[null, 1, 3, null, 2, 5]";
+  for (int k : {0, 2, 5, 6}) {
+    this->AssertSelectKJson(with_nulls, k);
+  }
+
+  this->AssertSelectKJson("[2, 4, 5, 7, 8, 0, 9, 1, 3]", 5);
+}
+
+// Boolean input with nulls and duplicates, for several k values.
+TYPED_TEST(TestSelectKForBool, Bool) {
+  const std::string values = "[null, false, true, null, false, true]";
+  for (int k : {0, 2, 5, 6}) {
+    this->AssertSelectKJson(values, k);
+  }
+}
+
+// Temporal input with nulls, for several k values.
+TYPED_TEST(TestSelectKForTemporal, Temporal) {
+  const std::string values = "[null, 1, 3, null, 2, 5]";
+  for (int k : {0, 2, 5, 6}) {
+    this->AssertSelectKJson(values, k);
+  }
+}
+
+// Decimal input with one null and mixed signs, for several k values.
+TYPED_TEST(TestSelectKForDecimal, Decimal) {
+  const std::string values = R"(["123.45", null, "-123.45", "456.78", "-456.78"])";
+  for (int k : {0, 2, 4, 5}) {
+    this->AssertSelectKJson(values, k);
+  }
+}
+
+// String input with two nulls, for several k values.
+TYPED_TEST(TestSelectKForStrings, Strings) {
+  const std::string values = R"(["testing", null, "nth", "for", null, "strings"])";
+  for (int k : {0, 2, 5, 6}) {
+    this->AssertSelectKJson(values, k);
+  }
+}
+
+/*struct TestSelectKWithKeepParam : public ::testing::Test {
+ void Check(const std::shared_ptr<DataType>& type, const std::string& values,
int64_t k,
+ const std::string& expected) {
+ std::shared_ptr<Array> actual;
+ ASSERT_OK(this->DoSelectK(type, values, k, &actual));
+ ValidateOutput(actual);
+
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ void Check(const std::shared_ptr<DataType>& type, const
std::shared_ptr<Array>& values,
+ int64_t k, const std::string& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, TopK(*values, k));
+ ValidateOutput(actual);
+ ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+ }
+
+ Status DoSelectK(const std::shared_ptr<DataType>& type, const std::string&
values,
+ int64_t k, std::shared_ptr<Array>* out) {
+ ARROW_ASSIGN_OR_RAISE(*out, TopK(*ArrayFromJSON(type, values), k));
+ PrettyPrint(**out, {}, &std::cerr);
+ return Status::OK();
+ }
+};
+
+TEST_F(TestSelectKWithKeepParam, Integral) {
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 0, "[]");
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 2, "[5, 3]");
+ this->Check(int32(), "[null, 1, 3, null, 2, 5]", 5, "[5, 3, 2, 1]");
+ this->Check(int32(), "[1, 2, 2, 3, 3, 3, 4, 4, 4, 4]", 5, "[4, 4, 4, 4, 3]");
+}*/
+
+// Fixture for the randomized tests. Arrays come from a random generator, so
+// GetType() is not expected to be called; it records a failure if it is.
+template <typename T>
+class TestSelectKRandom : public TestSelectKBase<T> {
+ public:
+  std::shared_ptr<DataType> GetType() override {
+    EXPECT_TRUE(0) << "shouldn't be used";
+    return nullptr;
+  }
+};
+
+// Type list the randomized SelectK tests are instantiated over.
+using SelectKableTypes = ::testing::Types<
+    // unsigned integers
+    UInt8Type, UInt16Type, UInt32Type, UInt64Type,
+    // signed integers
+    Int8Type, Int16Type, Int32Type, Int64Type,
+    // floating point
+    FloatType, DoubleType,
+    // decimal and string
+    Decimal128Type, StringType>;
+
+// Common base of the per-type random array helpers: owns the seeded
+// generator and the data type of the arrays to produce.
+class RandomImpl {
+ protected:
+  random::RandomArrayGenerator generator_;
+  std::shared_ptr<DataType> type_;
+
+  // Only derived helpers construct this; `type` is moved into the member.
+  explicit RandomImpl(random::SeedType seed, std::shared_ptr<DataType> type)
+      : generator_(seed), type_(std::move(type)) {}
+
+ public:
+  // Produces `count` values of `type_`, forwarding `null_prob` to ArrayOf.
+  std::shared_ptr<Array> Generate(uint64_t count, double null_prob) {
+    auto result = generator_.ArrayOf(type_, count, null_prob);
+    return result;
+  }
+};
+
+// Default random helper: uses the type singleton of a parameter-free type.
+template <typename T>
+class Random : public RandomImpl {
+ public:
+  explicit Random(random::SeedType seed)
+      : RandomImpl(seed, TypeTraits<T>::type_singleton()) {}
+};
+
+// float32 specialization: adds an optional NaN probability.
+template <>
+class Random<FloatType> : public RandomImpl {
+  using CType = float;
+
+ public:
+  explicit Random(random::SeedType seed) : RandomImpl(seed, float32()) {}
+
+  // Draws `count` values between numeric_limits<float>::min() and max().
+  // NOTE(review): for floating types min() is the smallest positive normal
+  // value, so negative values are presumably never generated - confirm intent.
+  std::shared_ptr<Array> Generate(uint64_t count, double null_prob,
+                                  double nan_prob = 0) {
+    const CType lower = std::numeric_limits<CType>::min();
+    const CType upper = std::numeric_limits<CType>::max();
+    return generator_.Float32(count, lower, upper, null_prob, nan_prob);
+  }
+};
+
+// float64 specialization: adds an optional NaN probability.
+template <>
+class Random<DoubleType> : public RandomImpl {
+  using CType = double;
+
+ public:
+  explicit Random(random::SeedType seed) : RandomImpl(seed, float64()) {}
+
+  // Draws `count` values between numeric_limits<double>::min() and max().
+  // NOTE(review): for floating types min() is the smallest positive normal
+  // value, so negative values are presumably never generated - confirm intent.
+  std::shared_ptr<Array> Generate(uint64_t count, double null_prob,
+                                  double nan_prob = 0) {
+    const CType lower = std::numeric_limits<CType>::min();
+    const CType upper = std::numeric_limits<CType>::max();
+    return generator_.Float64(count, lower, upper, null_prob, nan_prob);
+  }
+};
+
+// Decimal128 specialization: the type is parameterized, so it is passed in
+// explicitly and defaults to decimal128(18, 5).
+template <>
+class Random<Decimal128Type> : public RandomImpl {
+ public:
+  explicit Random(random::SeedType seed,
+                  std::shared_ptr<DataType> type = decimal128(18, 5))
+      : RandomImpl(seed, std::move(type)) {}
+};
+
+// Random helper producing numeric values limited to a window of width
+// `range` starting at the C type's minimum.
+template <typename ArrowType>
+class RandomRange : public RandomImpl {
+  using CType = typename TypeTraits<ArrowType>::CType;
+
+ public:
+  explicit RandomRange(random::SeedType seed)
+      : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
+
+  // Generates `count` values in [min, min + range]. For C types narrower
+  // than 4 bytes the upper bound is clamped to the type's maximum when
+  // min + range would exceed it.
+  std::shared_ptr<Array> Generate(uint64_t count, int range, double null_prob) {
+    using Limits = std::numeric_limits<CType>;
+    const CType min = Limits::min();
+    const bool exceeds_narrow_type = sizeof(CType) < 4 && (range + min) > Limits::max();
+    const CType max = exceeds_narrow_type ? Limits::max() : static_cast<CType>(min + range);
+    return generator_.Numeric<ArrowType>(count, min, max, null_prob);
+  }
+};
+
+TYPED_TEST_SUITE(TestSelectKRandom, SelectKableTypes);
+
+// Runs top-k and bottom-k against freshly generated random arrays for every
+// supported type, across several null densities.
+TYPED_TEST(TestSelectKRandom, RandomValues) {
+  const int length = 100;
+  Random<TypeParam> rand(0x61549225);
+  for (auto null_prob : {0.0, 0.1, 0.5, 1.0}) {
+    // Exercise every k from 0 (empty selection) through `length` (everything).
+    for (int k = 0; k <= length; ++k) {
+      // A new array is drawn for each k, so the generator state keeps advancing.
+      auto values = rand.Generate(length, null_prob);
+      this->AssertTopKArray(values, k);
+      this->AssertBottomKArray(values, k);
+    }
+  }
+}
+
+// Fixture for SelectK over chunked arrays, parameterized on the sort order
+// (which picks TopK or BottomK through SelectK<order>).
+template <SortOrder order>
+struct TestSelectKWithChunkedArray : public ::testing::Test {
+  TestSelectKWithChunkedArray()
+      : sizes_({0, 1, 2, 4, 16, 31, 1234}),
+        null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {}
+
+  // Builds a chunked array from one JSON string per chunk, selects k
+  // elements and compares the result with `expected` (JSON of `type`).
+  void Check(const std::shared_ptr<DataType>& type,
+             const std::vector<std::string>& values, int64_t k,
+             const std::string& expected) {
+    std::shared_ptr<Array> actual;
+    ASSERT_OK(this->DoSelectK(type, values, k, &actual));
+    ValidateOutput(actual);
+
+    ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+  }
+
+  // Same check for an already constructed chunked array.
+  void Check(const std::shared_ptr<DataType>& type,
+             const std::shared_ptr<ChunkedArray>& values, int64_t k,
+             const std::string& expected) {
+    ASSERT_OK_AND_ASSIGN(auto actual, SelectK<order>(*values, k));
+    ValidateOutput(actual);
+    ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual);
+  }
+
+  // Runs SelectK<order> on the chunked array parsed from `values` and stores
+  // the result in `*out`; a kernel error is propagated as the Status.
+  // The debug PrettyPrint of every result to std::cerr was removed so test
+  // logs stay readable.
+  Status DoSelectK(const std::shared_ptr<DataType>& type,
+                   const std::vector<std::string>& values, int64_t k,
+                   std::shared_ptr<Array>* out) {
+    ARROW_ASSIGN_OR_RAISE(*out, SelectK<order>(*(ChunkedArrayFromJSON(type, values)), k));
+    return Status::OK();
+  }
+
+  // Array sizes and null probabilities kept on the fixture; no code in this
+  // struct reads them.
+  std::vector<int32_t> sizes_;
+  std::vector<double> null_probabilities_;
+};
+
+struct TestTopKWithChunkedArray
Review comment:
You can make the 'base' overload of the test helpers take Datum and then
use ChunkedArrayFromJSON in the tests themselves, so that you don't have to
write too many overloads.
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -1778,6 +1784,711 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// ----------------------------------------------------------------------
+// TopK/BottomK implementations
+
+using SelectKOptionsState = internal::OptionsWrapper<SelectKOptions>;
+const auto kDefaultTopKOptions = SelectKOptions::TopKDefault();
+const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault();
+
+// Documentation for the "top_k" function. Fixed here: the duplicated word
+// "array array" and the options class, which named PartitionNthOptions
+// although the defaults above are SelectKOptions.
+// NOTE(review): the summary still reads like partition_nth ("partition ...
+// around a pivot") and the description is a "@TODO" placeholder; both need
+// real top_k wording before merging.
+const FunctionDoc top_k_doc(
+    "Return the indices that would partition an array, record batch or table\n"
+    "around a pivot",
+    ("@TODO"), {"input", "k"}, "SelectKOptions");
Review comment:
Is `k` really a parameter? Shouldn't these be unary functions?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]