This is an automated email from the ASF dual-hosted git repository.
zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new d1bf85631f GH-45344: [C++][Testing] Generic `StepGenerator` (#45345)
d1bf85631f is described below
commit d1bf85631f979fdde000bd27cf850c2d465a3d8f
Author: Rossi Sun <[email protected]>
AuthorDate: Mon Feb 10 20:04:28 2025 +0800
GH-45344: [C++][Testing] Generic `StepGenerator` (#45345)
### Rationale for this change
See GitHub issue #45344.
### What changes are included in this PR?
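Make the `StepGenerator` generic: `gen::Step` is now a header-only template over a numeric C type (defaulting to `uint32_t`), replacing the previous `uint32_t`-only signature with its `signed_int` flag.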
### Are these changes tested?
Yes, unit tests are included (new file `cpp/src/arrow/testing/generator_test.cc`).
### Are there any user-facing changes?
None.
* GitHub Issue: #45344
Lead-authored-by: Rossi Sun <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
---
cpp/src/arrow/acero/order_by_node_test.cc | 3 +-
cpp/src/arrow/acero/sorted_merge_node_test.cc | 3 +-
cpp/src/arrow/testing/CMakeLists.txt | 3 +-
cpp/src/arrow/testing/generator.cc | 41 ---------------
cpp/src/arrow/testing/generator.h | 44 ++++++++++++++--
cpp/src/arrow/testing/generator_test.cc | 74 +++++++++++++++++++++++++++
6 files changed, 119 insertions(+), 49 deletions(-)
diff --git a/cpp/src/arrow/acero/order_by_node_test.cc
b/cpp/src/arrow/acero/order_by_node_test.cc
index d77b0f3184..37e6862ed0 100644
--- a/cpp/src/arrow/acero/order_by_node_test.cc
+++ b/cpp/src/arrow/acero/order_by_node_test.cc
@@ -42,8 +42,7 @@ static constexpr int kRowsPerBatch = 4;
static constexpr int kNumBatches = 32;
std::shared_ptr<Table> TestTable() {
- return gen::Gen({{"up", gen::Step()},
- {"down", gen::Step(/*start=*/0, /*step=*/-1,
/*signed_int=*/true)}})
+ return gen::Gen({{"up", gen::Step()}, {"down", gen::Step(/*start=*/0,
/*step=*/-1)}})
->FailOnError()
->Table(kRowsPerBatch, kNumBatches);
}
diff --git a/cpp/src/arrow/acero/sorted_merge_node_test.cc
b/cpp/src/arrow/acero/sorted_merge_node_test.cc
index 55446d631d..82b630420c 100644
--- a/cpp/src/arrow/acero/sorted_merge_node_test.cc
+++ b/cpp/src/arrow/acero/sorted_merge_node_test.cc
@@ -36,8 +36,7 @@ namespace arrow::acero {
std::shared_ptr<Table> TestTable(int start, int step, int rows_per_batch,
int num_batches) {
- return gen::Gen({{"timestamp", gen::Step(start, step, /*signed_int=*/true)},
- {"str", gen::Random(utf8())}})
+ return gen::Gen({{"timestamp", gen::Step(start, step)}, {"str",
gen::Random(utf8())}})
->FailOnError()
->Table(rows_per_batch, num_batches);
}
diff --git a/cpp/src/arrow/testing/CMakeLists.txt
b/cpp/src/arrow/testing/CMakeLists.txt
index 6cf4b2d2b1..82db590e33 100644
--- a/cpp/src/arrow/testing/CMakeLists.txt
+++ b/cpp/src/arrow/testing/CMakeLists.txt
@@ -18,8 +18,9 @@
arrow_install_all_headers("arrow/testing")
if(ARROW_BUILD_TESTS)
- add_arrow_test(random_test)
+ add_arrow_test(generator_test)
add_arrow_test(gtest_util_test)
+ add_arrow_test(random_test)
if(ARROW_FILESYSTEM)
add_library(arrow_filesystem_example MODULE examplefs.cc)
diff --git a/cpp/src/arrow/testing/generator.cc
b/cpp/src/arrow/testing/generator.cc
index 5ea6a541e8..8715ecdeb5 100644
--- a/cpp/src/arrow/testing/generator.cc
+++ b/cpp/src/arrow/testing/generator.cc
@@ -26,7 +26,6 @@
#include "arrow/array.h"
#include "arrow/buffer.h"
-#include "arrow/builder.h"
#include "arrow/compute/exec.h"
#include "arrow/datum.h"
#include "arrow/record_batch.h"
@@ -220,42 +219,6 @@ class ConstantGenerator : public ArrayGenerator {
std::shared_ptr<Scalar> value_;
};
-class StepGenerator : public ArrayGenerator {
- public:
- StepGenerator(uint32_t start, uint32_t step, bool signed_int)
- : start_(start), step_(step), signed_int_(signed_int) {}
-
- template <typename BuilderType, typename CType>
- Result<std::shared_ptr<Array>> DoGenerate(int64_t num_rows) {
- BuilderType builder;
- ARROW_RETURN_NOT_OK(builder.Reserve(num_rows));
- CType val = start_;
- for (int64_t i = 0; i < num_rows; i++) {
- builder.UnsafeAppend(val);
- val += step_;
- }
- start_ = val;
- return builder.Finish();
- }
-
- Result<std::shared_ptr<Array>> Generate(int64_t num_rows) override {
- if (signed_int_) {
- return DoGenerate<Int32Builder, int32_t>(num_rows);
- } else {
- return DoGenerate<UInt32Builder, uint32_t>(num_rows);
- }
- }
-
- std::shared_ptr<DataType> type() const override {
- return signed_int_ ? int32() : uint32();
- }
-
- private:
- uint32_t start_;
- uint32_t step_;
- bool signed_int_;
-};
-
static constexpr random::SeedType kTestSeed = 42;
class RandomGenerator : public ArrayGenerator {
@@ -405,10 +368,6 @@ std::shared_ptr<ArrayGenerator>
Constant(std::shared_ptr<Scalar> value) {
return std::make_shared<ConstantGenerator>(std::move(value));
}
-std::shared_ptr<ArrayGenerator> Step(uint32_t start, uint32_t step, bool
signed_int) {
- return std::make_shared<StepGenerator>(start, step, signed_int);
-}
-
std::shared_ptr<ArrayGenerator> Random(std::shared_ptr<DataType> type) {
return std::make_shared<RandomGenerator>(std::move(type));
}
diff --git a/cpp/src/arrow/testing/generator.h
b/cpp/src/arrow/testing/generator.h
index 4ec8845864..e90c125a7d 100644
--- a/cpp/src/arrow/testing/generator.h
+++ b/cpp/src/arrow/testing/generator.h
@@ -23,6 +23,8 @@
#include <vector>
#include "arrow/array/array_base.h"
+#include "arrow/array/util.h"
+#include "arrow/buffer_builder.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/visibility.h"
@@ -301,12 +303,48 @@ ARROW_TESTING_EXPORT std::shared_ptr<DataGenerator> Gen(
/// make a generator that returns a constant value
ARROW_TESTING_EXPORT std::shared_ptr<ArrayGenerator> Constant(
std::shared_ptr<Scalar> value);
+
/// make a generator that returns an incrementing value
///
/// Note: overflow is not prevented; standard unsigned integer overflow applies
-ARROW_TESTING_EXPORT std::shared_ptr<ArrayGenerator> Step(uint32_t start = 0,
- uint32_t step = 1,
- bool signed_int =
false);
+template <typename T = uint32_t>
+std::shared_ptr<ArrayGenerator> Step(T start = 0, T step = 1) {
+ class StepGenerator : public ArrayGenerator {
+ public:
+ // Use [[maybe_unused]] to avoid a compiler warning in Clang versions
before 15 that
+ // incorrectly reports 'unused type alias'.
+ using ArrowType [[maybe_unused]] = typename CTypeTraits<T>::ArrowType;
+ static_assert(is_number_type<ArrowType>::value,
+ "Step generator only supports numeric types");
+
+ StepGenerator(T start, T step) : start_(start), step_(step) {}
+
+ Result<std::shared_ptr<Array>> Generate(int64_t num_rows) override {
+ TypedBufferBuilder<T> builder;
+ ARROW_RETURN_NOT_OK(builder.Reserve(num_rows));
+ T val = start_;
+ for (int64_t i = 0; i < num_rows; i++) {
+ builder.UnsafeAppend(val);
+ val += step_;
+ }
+ start_ = val;
+ ARROW_ASSIGN_OR_RAISE(auto buf, builder.Finish());
+ return
MakeArray(ArrayData::Make(TypeTraits<ArrowType>::type_singleton(), num_rows,
+ {NULLPTR, std::move(buf)},
/*null_count=*/0));
+ }
+
+ std::shared_ptr<DataType> type() const override {
+ return TypeTraits<ArrowType>::type_singleton();
+ }
+
+ private:
+ T start_;
+ T step_;
+ };
+
+ return std::make_shared<StepGenerator>(start, step);
+}
+
/// make a generator that returns a random value
ARROW_TESTING_EXPORT std::shared_ptr<ArrayGenerator> Random(
std::shared_ptr<DataType> type);
diff --git a/cpp/src/arrow/testing/generator_test.cc
b/cpp/src/arrow/testing/generator_test.cc
new file mode 100644
index 0000000000..f4b38ee431
--- /dev/null
+++ b/cpp/src/arrow/testing/generator_test.cc
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/generator.h"
+
+namespace arrow::gen {
+
+template <typename CType>
+void CheckStep(const Array& result, CType start, CType step, int64_t length) {
+ using ArrowType = typename CTypeTraits<CType>::ArrowType;
+
+ ASSERT_OK(result.ValidateFull());
+ ASSERT_EQ(result.type_id(), TypeTraits<ArrowType>::type_singleton()->id());
+ ASSERT_EQ(result.length(), length);
+ ASSERT_EQ(result.null_bitmap(), nullptr);
+ auto data = result.data()->GetValues<CType>(1);
+ CType current = start;
+ for (int64_t i = 0; i < length; ++i) {
+ ASSERT_EQ(data[i], current);
+ current += step;
+ }
+}
+
+TEST(StepTest, Default) {
+ for (auto length : {0, 1, 1024}) {
+ ARROW_SCOPED_TRACE("length=" + std::to_string(length));
+ ASSERT_OK_AND_ASSIGN(auto array, Step()->Generate(length));
+ CheckStep<uint32_t>(*array, 0, 1, length);
+ }
+}
+
+using NumericCTypes = ::testing::Types<int8_t, uint8_t, int16_t, uint16_t,
int32_t,
+ uint32_t, int64_t, uint64_t, float,
double>;
+
+template <typename CType>
+class TypedStepTest : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TypedStepTest, NumericCTypes);
+
+TYPED_TEST(TypedStepTest, Basic) {
+ for (auto length : {0, 1, 1024}) {
+ ARROW_SCOPED_TRACE("length=" + std::to_string(length));
+ for (TypeParam start :
+ {std::numeric_limits<TypeParam>::min(), static_cast<TypeParam>(0)}) {
+ ARROW_SCOPED_TRACE("start=" + std::to_string(start));
+ for (TypeParam step :
+ {static_cast<TypeParam>(0),
std::numeric_limits<TypeParam>::epsilon(),
+ static_cast<TypeParam>(std::numeric_limits<TypeParam>::max() /
+ (length + 1))}) {
+ ARROW_SCOPED_TRACE("step=" + std::to_string(step));
+ ASSERT_OK_AND_ASSIGN(auto array, Step(start, step)->Generate(length));
+ CheckStep(*array, start, step, length);
+ }
+ }
+ }
+}
+
+} // namespace arrow::gen