westonpace commented on a change in pull request #10134:
URL: https://github.com/apache/arrow/pull/10134#discussion_r618738629
##########
File path: cpp/src/arrow/dataset/test_util.h
##########
@@ -303,6 +304,316 @@ template <typename P>
class DatasetFixtureMixinWithParam : public DatasetFixtureMixin,
public ::testing::WithParamInterface<P>
{};
+struct TestScannerParams {
+ bool use_async;
+ bool use_threads;
+ int num_child_datasets;
+ int num_batches;
+ int items_per_batch;
+
+ int64_t total_batches() const { return num_child_datasets * num_batches; }
+
+ int64_t expected_rows() const { return total_batches() * items_per_batch; }
+
+ std::string ToString() const {
+ // GTest requires this to be alphanumeric
+ std::stringstream ss;
+ ss << (use_async ? "Async" : "Sync") << (use_threads ? "Threaded" : "Serial")
+ << num_child_datasets << "d" << num_batches << "b" << items_per_batch << "r";
+ return ss.str();
+ }
+
+ static std::string ToTestNameString(
+ const ::testing::TestParamInfo<TestScannerParams>& info) {
+ return std::to_string(info.index) + info.param.ToString();
+ }
+
+ static std::vector<TestScannerParams> Values() {
+ std::vector<TestScannerParams> values;
+ for (int sync = 0; sync < 2; sync++) {
+ for (int use_threads = 0; use_threads < 2; use_threads++) {
+ values.push_back(
+ {static_cast<bool>(sync), static_cast<bool>(use_threads), 1, 1, 1024});
+ values.push_back(
+ {static_cast<bool>(sync), static_cast<bool>(use_threads), 2, 16, 1024});
+ }
+ }
+ return values;
+ }
+};
+
+std::ostream& operator<<(std::ostream& out, const TestScannerParams& params) {
+ out << (params.use_async ? "async-" : "sync-")
+ << (params.use_threads ? "threaded-" : "serial-") << params.num_child_datasets
+ << "d-" << params.num_batches << "b-" << params.items_per_batch << "i";
+ return out;
+}
+
+class FileFormatWriterMixin {
+ virtual std::shared_ptr<Buffer> Write(RecordBatchReader* reader) = 0;
+ virtual std::shared_ptr<Buffer> Write(const Table& table) = 0;
+};
+
+/// WriterMixin should be a class with these static methods:
+/// std::shared_ptr<Buffer> Write(RecordBatchReader* reader);
+template <typename WriterMixin>
+class FileFormatFixtureMixin : public ::testing::Test {
+ public:
+ constexpr static int64_t kBatchSize = 1UL << 12;
+ constexpr static int64_t kBatchRepetitions = 1 << 5;
+
+ int64_t expected_batches() const { return kBatchRepetitions; }
+ int64_t expected_rows() const { return kBatchSize * kBatchRepetitions; }
+
+ std::shared_ptr<FileSource> GetFileSource(RecordBatchReader* reader) {
+ auto buffer = WriterMixin::Write(reader);
+ return std::make_shared<FileSource>(std::move(buffer));
+ }
+
+ virtual std::shared_ptr<RecordBatchReader> GetRecordBatchReader(
+ std::shared_ptr<Schema> schema) {
+ return MakeGeneratedRecordBatch(schema, kBatchSize, kBatchRepetitions);
+ }
+
+ Result<std::shared_ptr<io::BufferOutputStream>> GetFileSink() {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> buffer,
+ AllocateResizableBuffer(0));
+ return std::make_shared<io::BufferOutputStream>(buffer);
+ }
+
+ void SetSchema(std::vector<std::shared_ptr<Field>> fields) {
+ opts_ = std::make_shared<ScanOptions>();
+ opts_->dataset_schema = schema(std::move(fields));
+ ASSERT_OK(SetProjection(opts_.get(), opts_->dataset_schema->field_names()));
+ }
+
+ void SetFilter(Expression filter) {
+ ASSERT_OK_AND_ASSIGN(opts_->filter, filter.Bind(*opts_->dataset_schema));
+ }
+
+ void Project(std::vector<std::string> names) {
+ ASSERT_OK(SetProjection(opts_.get(), std::move(names)));
+ }
+
+ // Shared test cases
+ void TestOpenFailureWithRelevantError(FileFormat* format, StatusCode code) {
+ std::shared_ptr<Buffer> buf = std::make_shared<Buffer>(util::string_view(""));
+ auto result = format->Inspect(FileSource(buf));
+ EXPECT_FALSE(result.ok());
+ EXPECT_EQ(code, result.status().code());
+ EXPECT_THAT(result.status().ToString(), testing::HasSubstr("<Buffer>"));
+
+ constexpr auto file_name = "herp/derp";
+ ASSERT_OK_AND_ASSIGN(
+ auto fs, fs::internal::MockFileSystem::Make(fs::kNoTime, {fs::File(file_name)}));
+ result = format->Inspect({file_name, fs});
+ EXPECT_FALSE(result.ok());
+ EXPECT_EQ(code, result.status().code());
+ EXPECT_THAT(result.status().ToString(), testing::HasSubstr(file_name));
+ }
+ void TestInspect(FileFormat* format) {
+ auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = GetFileSource(reader.get());
+
+ ASSERT_OK_AND_ASSIGN(auto actual, format->Inspect(*source.get()));
+ AssertSchemaEqual(*actual, *reader->schema(), /*check_metadata=*/false);
+ }
+ void TestIsSupported(FileFormat* format) {
+ auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+ auto source = GetFileSource(reader.get());
+
+ bool supported = false;
+
+ std::shared_ptr<Buffer> buf = std::make_shared<Buffer>(util::string_view(""));
+ ASSERT_OK_AND_ASSIGN(supported, format->IsSupported(FileSource(buf)));
+ ASSERT_EQ(supported, false);
+
+ buf = std::make_shared<Buffer>(util::string_view("corrupted"));
+ ASSERT_OK_AND_ASSIGN(supported, format->IsSupported(FileSource(buf)));
+ ASSERT_EQ(supported, false);
+
+ ASSERT_OK_AND_ASSIGN(supported, format->IsSupported(*source));
+ EXPECT_EQ(supported, true);
+ }
+ std::shared_ptr<Buffer> TestWrite(FileFormat* format, std::shared_ptr<Schema> schema,
+ std::shared_ptr<FileWriteOptions> options = nullptr) {
+ SetSchema(schema->fields());
+ EXPECT_OK_AND_ASSIGN(auto sink, GetFileSink());
+
+ if (!options) options = format->DefaultWriteOptions();
+ EXPECT_OK_AND_ASSIGN(auto writer, format->MakeWriter(sink, schema, options));
+ ARROW_EXPECT_OK(writer->Write(GetRecordBatchReader(schema).get()));
+ ARROW_EXPECT_OK(writer->Finish());
+ EXPECT_OK_AND_ASSIGN(auto written, sink->Finish());
+ return written;
+ }
+
+ protected:
+ std::shared_ptr<ScanOptions> opts_ = std::make_shared<ScanOptions>();
+};
+
+template <typename Writer>
+class FileFormatScanMixin : public FileFormatFixtureMixin<Writer>,
+ public ::testing::WithParamInterface<TestScannerParams> {
+ public:
+ int64_t expected_batches() const { return GetParam().total_batches(); }
+ int64_t expected_rows() const { return GetParam().expected_rows(); }
+
+ std::shared_ptr<RecordBatchReader> GetRecordBatchReader(
+ std::shared_ptr<Schema> schema) override {
+ return MakeGeneratedRecordBatch(schema, GetParam().items_per_batch,
+ GetParam().total_batches());
+ }
+
+ // Scan the fragment through the scanner.
+ RecordBatchIterator Batches(std::shared_ptr<Fragment> fragment) {
Review comment:
I could be easily convinced either way. I think it's also good to
detect potential interaction bugs. I suppose the scanner could do various
things with scan options (like the projection issue you mentioned) that aren't
so easily exhaustively tested in formats.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]