kou commented on a change in pull request #8612:
URL: https://github.com/apache/arrow/pull/8612#discussion_r526619205
##########
File path: cpp/src/arrow/compute/kernels/vector_sort_test.cc
##########
@@ -364,32 +400,264 @@ TYPED_TEST(TestSortToIndicesKernelRandomCount,
SortRandomValuesCount) {
int range = 2000;
for (int test = 0; test < times; test++) {
for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
- auto array = rand.Generate(length, range, null_probability);
- ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
SortToIndices(*array));
- ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
- *checked_pointer_cast<UInt64Array>(offsets));
+ for (auto order : {SortOrder::ASCENDING, SortOrder::DESCENDING}) {
+ auto array = rand.Generate(length, range, null_probability);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
SortIndices(*array, order));
+ ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
+ *checked_pointer_cast<UInt64Array>(offsets),
order);
+ }
}
}
}
// Long array with big value range: std::stable_sort
-TYPED_TEST_SUITE(TestSortToIndicesKernelRandomCompare, IntegralArrowTypes);
+TYPED_TEST_SUITE(TestArraySortIndicesKernelRandomCompare, IntegralArrowTypes);
-TYPED_TEST(TestSortToIndicesKernelRandomCompare, SortRandomValuesCompare) {
+TYPED_TEST(TestArraySortIndicesKernelRandomCompare, SortRandomValuesCompare) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
Random<TypeParam> rand(0x5487657);
int times = 5;
int length = 4000;
for (int test = 0; test < times; test++) {
for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) {
- auto array = rand.Generate(length, null_probability);
- ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
SortToIndices(*array));
- ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
- *checked_pointer_cast<UInt64Array>(offsets));
+ for (auto order : {SortOrder::ASCENDING, SortOrder::DESCENDING}) {
+ auto array = rand.Generate(length, null_probability);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets,
SortIndices(*array, order));
+ ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
+ *checked_pointer_cast<UInt64Array>(offsets),
order);
+ }
}
}
}
+class TestTableSortIndices : public ::testing::Test {
+ protected:
+ void AssertSortIndices(const std::shared_ptr<Table> table, const
SortOptions& options,
+ const std::shared_ptr<Array> expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*table, options));
+ AssertArraysEqual(*expected, *actual);
+ }
+
+ void AssertSortIndices(const std::shared_ptr<Table> table, const
SortOptions& options,
+ const std::string expected) {
+ AssertSortIndices(table, options, ArrayFromJSON(uint64(), expected));
+ }
+};
+
+TEST_F(TestTableSortIndices, SortNull) {
+ auto table = TableFromJSON(schema({
+ {field("a", uint8())},
+ {field("b", uint8())},
+ }),
+ {"["
+ "{\"a\": null, \"b\": 5},"
+ "{\"a\": 1, \"b\": 3},"
+ "{\"a\": 3, \"b\": null},"
+ "{\"a\": null, \"b\": null},"
+ "{\"a\": 2, \"b\": 5},"
+ "{\"a\": 1, \"b\": 5}"
+ "]"});
+ SortOptions options(
+ {SortKey("a", SortOrder::ASCENDING), SortKey("b",
SortOrder::DESCENDING)});
+ this->AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]");
+}
+
+TEST_F(TestTableSortIndices, SortNaN) {
+ auto table = TableFromJSON(schema({
+ {field("a", float32())},
+ {field("b", float32())},
+ }),
+ {"["
+ "{\"a\": null, \"b\": 5},"
+ "{\"a\": 1, \"b\": 3},"
+ "{\"a\": 3, \"b\": null},"
+ "{\"a\": null, \"b\": null},"
+ "{\"a\": NaN, \"b\": null},"
+ "{\"a\": NaN, \"b\": 5},"
+ "{\"a\": NaN, \"b\": NaN},"
+ "{\"a\": 1, \"b\": 5}"
+ "]"});
+ SortOptions options(
+ {SortKey("a", SortOrder::ASCENDING), SortKey("b",
SortOrder::DESCENDING)});
+ this->AssertSortIndices(table, options, "[7, 1, 2, 5, 6, 4, 0, 3]");
+}
+
+using RandomParam = std::tuple<std::string, double>;
+class TestTableSortIndicesRandom : public testing::TestWithParam<RandomParam> {
+ class Comparator : public TypeVisitor {
+ public:
+ bool operator()(const Table& table, const SortOptions& options, uint64_t
lhs,
+ uint64_t rhs) {
+ lhs_ = lhs;
+ rhs_ = rhs;
+ for (const auto& sort_key : options.sort_keys) {
+ auto chunked_array = table.GetColumnByName(sort_key.name);
+ lhs_array_ = findTargetArray(chunked_array, lhs, lhs_index_);
+ rhs_array_ = findTargetArray(chunked_array, rhs, rhs_index_);
+ if (rhs_array_->IsNull(rhs_index_) && lhs_array_->IsNull(lhs_index_))
continue;
+ if (rhs_array_->IsNull(rhs_index_)) return true;
+ if (lhs_array_->IsNull(lhs_index_)) return false;
+ status_ = lhs_array_->type()->Accept(this);
+ if (compared_ == 0) continue;
+ if (sort_key.order == SortOrder::ASCENDING) {
+ return compared_ < 0;
+ } else {
+ return compared_ > 0;
+ }
+ }
+ return lhs < rhs;
+ }
+
+ Status status() const { return status_; }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE##Type& type) override { \
+ compared_ = CompareType<TYPE##Type>(); \
+ return Status::OK(); \
+ }
+
+ VISIT(Int8)
+ VISIT(Int16)
+ VISIT(Int32)
+ VISIT(Int64)
+ VISIT(UInt8)
+ VISIT(UInt16)
+ VISIT(UInt32)
+ VISIT(UInt64)
+ VISIT(Float)
+ VISIT(Double)
+ VISIT(String)
+
+#undef VISIT
+
+ private:
+ std::shared_ptr<Array> findTargetArray(std::shared_ptr<ChunkedArray>
chunked_array,
Review comment:
Changed.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org