lidavidm commented on a change in pull request #10134:
URL: https://github.com/apache/arrow/pull/10134#discussion_r618709167
##########
File path: cpp/src/arrow/dataset/test_util.h
##########
@@ -303,6 +304,316 @@ template <typename P>
class DatasetFixtureMixinWithParam : public DatasetFixtureMixin,
public ::testing::WithParamInterface<P>
{};
+/// Parameters for the scanner test matrix: sync vs. async scanning, serial vs.
+/// threaded execution, and the shape of the generated test data.
+struct TestScannerParams {
+  bool use_async;
+  bool use_threads;
+  int num_child_datasets;
+  int num_batches;
+  int items_per_batch;
+
+  /// Number of batches across all child datasets.
+  int64_t total_batches() const {
+    // Widen before multiplying so the product is computed in 64 bits and
+    // cannot overflow `int`.
+    return static_cast<int64_t>(num_child_datasets) * num_batches;
+  }
+
+  /// Number of rows a full scan is expected to produce.
+  int64_t expected_rows() const { return total_batches() * items_per_batch; }
+
+  /// Compact description of the parameters, e.g. "AsyncThreaded2d16b1024r".
+  std::string ToString() const {
+    // GTest requires this to be alphanumeric
+    std::stringstream ss;
+    ss << (use_async ? "Async" : "Sync") << (use_threads ? "Threaded" : "Serial")
+       << num_child_datasets << "d" << num_batches << "b" << items_per_batch << "r";
+    return ss.str();
+  }
+
+  /// Test-name generator taking GTest's TestParamInfo. Prefixing the parameter
+  /// index keeps names unique even if two parameter sets format identically.
+  static std::string ToTestNameString(
+      const ::testing::TestParamInfo<TestScannerParams>& info) {
+    return std::to_string(info.index) + info.param.ToString();
+  }
+
+  /// The full parameter matrix: each sync/async x serial/threaded combination
+  /// with a small dataset (1 child, 1 batch) and a larger one (2 children,
+  /// 16 batches), both using 1024 rows per batch.
+  static std::vector<TestScannerParams> Values() {
+    std::vector<TestScannerParams> values;
+    values.reserve(8);
+    for (bool use_async : {false, true}) {
+      for (bool use_threads : {false, true}) {
+        values.push_back({use_async, use_threads, 1, 1, 1024});
+        values.push_back({use_async, use_threads, 2, 16, 1024});
+      }
+    }
+    return values;
+  }
+};
+
+/// Prints the parameters in a readable, dash-separated form, e.g.
+/// "async-threaded-2d-16b-1024i".
+///
+/// Marked `inline` because it is defined in a header. Without it, every
+/// translation unit including this file emits its own definition, and the link
+/// fails with duplicate symbols (an ODR violation).
+inline std::ostream& operator<<(std::ostream& out, const TestScannerParams& params) {
+  out << (params.use_async ? "async-" : "sync-")
+      << (params.use_threads ? "threaded-" : "serial-") << params.num_child_datasets
+      << "d-" << params.num_batches << "b-" << params.items_per_batch << "i";
+  return out;
+}
+
+/// Abstract interface for objects that serialize test data into an in-memory
+/// buffer of some file format.
+/// NOTE(review): the fixture templates visible here call `WriterMixin::Write`
+/// statically and do not use this interface. Confirm whether it is still needed.
+class FileFormatWriterMixin {
+ public:
+  // Virtual so that deleting an implementation through this base is well-defined.
+  virtual ~FileFormatWriterMixin() = default;
+
+  /// Write the data from `reader` and return the resulting buffer.
+  virtual std::shared_ptr<Buffer> Write(RecordBatchReader* reader) = 0;
+  /// Write `table` and return the resulting buffer.
+  virtual std::shared_ptr<Buffer> Write(const Table& table) = 0;
+};
+
+/// Test fixture with helpers and shared test cases for FileFormat implementations.
+///
+/// WriterMixin must provide a static method:
+///   static std::shared_ptr<Buffer> Write(RecordBatchReader* reader);
+/// It is used to produce the file contents under test.
+template <typename WriterMixin>
+class FileFormatFixtureMixin : public ::testing::Test {
+ public:
+  // Signed shifts so each initializer has the same type as the int64_t constant.
+  constexpr static int64_t kBatchSize = int64_t{1} << 12;
+  constexpr static int64_t kBatchRepetitions = int64_t{1} << 5;
+
+  int64_t expected_batches() const { return kBatchRepetitions; }
+  int64_t expected_rows() const { return kBatchSize * kBatchRepetitions; }
+
+  /// Serialize `reader` with WriterMixin and wrap the bytes in a FileSource.
+  std::shared_ptr<FileSource> GetFileSource(RecordBatchReader* reader) {
+    auto buffer = WriterMixin::Write(reader);
+    return std::make_shared<FileSource>(std::move(buffer));
+  }
+
+  /// Generate test data for `schema`. Presumably this is kBatchRepetitions
+  /// batches of kBatchSize rows, matching expected_batches()/expected_rows().
+  /// Subclasses may override this to change the data shape.
+  virtual std::shared_ptr<RecordBatchReader> GetRecordBatchReader(
+      std::shared_ptr<Schema> schema) {
+    return MakeGeneratedRecordBatch(schema, kBatchSize, kBatchRepetitions);
+  }
+
+  /// Create an in-memory output stream over an initially empty resizable buffer.
+  Result<std::shared_ptr<io::BufferOutputStream>> GetFileSink() {
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ResizableBuffer> buffer,
+                          AllocateResizableBuffer(0));
+    return std::make_shared<io::BufferOutputStream>(buffer);
+  }
+
+  /// Reset opts_ to fresh ScanOptions whose dataset schema is `fields`,
+  /// projecting every field.
+  void SetSchema(std::vector<std::shared_ptr<Field>> fields) {
+    opts_ = std::make_shared<ScanOptions>();
+    opts_->dataset_schema = schema(std::move(fields));
+    ASSERT_OK(SetProjection(opts_.get(), opts_->dataset_schema->field_names()));
+  }
+
+  /// Bind `filter` against the current dataset schema and store it in opts_.
+  void SetFilter(Expression filter) {
+    ASSERT_OK_AND_ASSIGN(opts_->filter, filter.Bind(*opts_->dataset_schema));
+  }
+
+  /// Replace the projection in opts_ with the named columns.
+  void Project(std::vector<std::string> names) {
+    ASSERT_OK(SetProjection(opts_.get(), std::move(names)));
+  }
+
+  // Shared test cases
+
+  /// Check that Inspect fails with `code` both on an empty buffer and on a
+  /// (presumably empty) mock file. The error message must name the source:
+  /// "<Buffer>" for the buffer, or the file path for the file.
+  void TestOpenFailureWithRelevantError(FileFormat* format, StatusCode code) {
+    std::shared_ptr<Buffer> buf = std::make_shared<Buffer>(util::string_view(""));
+    auto result = format->Inspect(FileSource(buf));
+    EXPECT_FALSE(result.ok());
+    EXPECT_EQ(code, result.status().code());
+    EXPECT_THAT(result.status().ToString(), testing::HasSubstr("<Buffer>"));
+
+    constexpr auto file_name = "herp/derp";
+    ASSERT_OK_AND_ASSIGN(
+        auto fs, fs::internal::MockFileSystem::Make(fs::kNoTime, {fs::File(file_name)}));
+    result = format->Inspect({file_name, fs});
+    EXPECT_FALSE(result.ok());
+    EXPECT_EQ(code, result.status().code());
+    EXPECT_THAT(result.status().ToString(), testing::HasSubstr(file_name));
+  }
+
+  /// Check that Inspect recovers the schema that was written (ignoring metadata).
+  void TestInspect(FileFormat* format) {
+    auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+    auto source = GetFileSource(reader.get());
+
+    ASSERT_OK_AND_ASSIGN(auto actual, format->Inspect(*source.get()));
+    AssertSchemaEqual(*actual, *reader->schema(), /*check_metadata=*/false);
+  }
+
+  /// Check that IsSupported rejects an empty buffer and garbage bytes, and
+  /// accepts a file written by WriterMixin.
+  void TestIsSupported(FileFormat* format) {
+    auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+    auto source = GetFileSource(reader.get());
+
+    bool supported = false;
+
+    std::shared_ptr<Buffer> buf = std::make_shared<Buffer>(util::string_view(""));
+    ASSERT_OK_AND_ASSIGN(supported, format->IsSupported(FileSource(buf)));
+    ASSERT_FALSE(supported);
+
+    buf = std::make_shared<Buffer>(util::string_view("corrupted"));
+    ASSERT_OK_AND_ASSIGN(supported, format->IsSupported(FileSource(buf)));
+    ASSERT_FALSE(supported);
+
+    ASSERT_OK_AND_ASSIGN(supported, format->IsSupported(*source));
+    EXPECT_TRUE(supported);
+  }
+
+  /// Write data generated for `schema` through the format's writer and return
+  /// the written bytes.
+  ///
+  /// `options` defaults to format->DefaultWriteOptions() when null. As a side
+  /// effect, this also calls SetSchema(schema->fields()).
+  std::shared_ptr<Buffer> TestWrite(FileFormat* format, std::shared_ptr<Schema> schema,
+                                    std::shared_ptr<FileWriteOptions> options = nullptr) {
+    SetSchema(schema->fields());
+    EXPECT_OK_AND_ASSIGN(auto sink, GetFileSink());
+
+    if (!options) options = format->DefaultWriteOptions();
+    EXPECT_OK_AND_ASSIGN(auto writer, format->MakeWriter(sink, schema, options));
+    ARROW_EXPECT_OK(writer->Write(GetRecordBatchReader(schema).get()));
+    ARROW_EXPECT_OK(writer->Finish());
+    EXPECT_OK_AND_ASSIGN(auto written, sink->Finish());
+    return written;
+  }
+
+ protected:
+  std::shared_ptr<ScanOptions> opts_ = std::make_shared<ScanOptions>();
+};
+
+template <typename Writer>
+class FileFormatScanMixin : public FileFormatFixtureMixin<Writer>,
+ public
::testing::WithParamInterface<TestScannerParams> {
+ public:
+  /// Batches a full scan should yield under the current test parameters.
+  int64_t expected_batches() const {
+    const TestScannerParams& params = GetParam();
+    return params.total_batches();
+  }
+  /// Rows a full scan should yield under the current test parameters.
+  int64_t expected_rows() const {
+    const TestScannerParams& params = GetParam();
+    return params.expected_rows();
+  }
+
+  /// Generate test data sized by the test parameters. Presumably this is
+  /// total_batches() batches of items_per_batch rows, so the data lines up
+  /// with expected_batches()/expected_rows().
+  std::shared_ptr<RecordBatchReader> GetRecordBatchReader(
+      std::shared_ptr<Schema> schema) override {
+    const TestScannerParams& params = GetParam();
+    const auto rows_per_batch = params.items_per_batch;
+    const auto batch_count = params.total_batches();
+    return MakeGeneratedRecordBatch(schema, rows_per_batch, batch_count);
+  }
+
+  /// Scan the fragment through the scanner.
+  ///
+  /// Builds a single-fragment dataset over `fragment`, configures a Scanner
+  /// with the async/threads test parameters, and yields the untagged batches.
+  RecordBatchIterator Batches(std::shared_ptr<Fragment> fragment) {
+    EXPECT_OK_AND_ASSIGN(auto schema, fragment->ReadPhysicalSchema());
+    auto dataset = std::make_shared<FragmentDataset>(schema, FragmentVector{fragment});
+    // opts_ is declared in the dependent base FileFormatFixtureMixin<Writer>,
+    // so it must be reached through `this->` for two-phase lookup to find it.
+    ScannerBuilder builder(dataset, this->opts_);
+    ARROW_EXPECT_OK(builder.UseAsync(GetParam().use_async));
+    ARROW_EXPECT_OK(builder.UseThreads(GetParam().use_threads));
+    EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+    EXPECT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+    return MakeMapIterator([](TaggedRecordBatch tagged) { return tagged.record_batch; },
+                           std::move(batch_it));
+  }
+
+  /// Scan the fragment directly, without using the scanner.
+  ///
+  /// In async mode, drains Fragment::ScanBatchesAsync through a blocking
+  /// iterator. Otherwise, executes each ScanTask from Fragment::Scan and
+  /// flattens the results.
+  RecordBatchIterator PhysicalBatches(std::shared_ptr<Fragment> fragment) {
+    // opts_ is declared in the dependent base FileFormatFixtureMixin<Writer>,
+    // so it must be reached through `this->` for two-phase lookup to find it.
+    if (GetParam().use_async) {
+      EXPECT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(this->opts_));
+      EXPECT_OK_AND_ASSIGN(auto batch_it, MakeGeneratorIterator(std::move(batch_gen)));
+      return batch_it;
+    }
+    EXPECT_OK_AND_ASSIGN(auto scan_task_it, fragment->Scan(this->opts_));
+    return MakeFlattenIterator(MakeMaybeMapIterator(
+        [](std::shared_ptr<ScanTask> scan_task) { return scan_task->Execute(); },
+        std::move(scan_task_it)));
+  }
+
+  // Shared test cases
+
+  /// Write generated float64 data, then check that scanning the resulting
+  /// fragment through the scanner returns every expected row.
+  void TestScan(FileFormat* format) {
+    auto reader = GetRecordBatchReader(schema({field("f64", float64())}));
+    auto source = this->GetFileSource(reader.get());
+    this->SetSchema(reader->schema()->fields());
+
+    ASSERT_OK_AND_ASSIGN(auto fragment, format->MakeFragment(*source));
+
+    int64_t total_rows = 0;
+    for (auto maybe_batch : Batches(fragment)) {
+      ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+      total_rows += batch->num_rows();
+    }
+    ASSERT_EQ(total_rows, expected_rows());
+  }
+
+ void TestScanProjected(FileFormat* format) {
Review comment:
This tests whether each format honors the projection in the ScanOptions and returns
only the columns needed to evaluate the projection and, later on, the filter.
Notably, the CSV format did not do this correctly before.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org