This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 7c4a07e ARROW-12487: [C++][Dataset] Fix ScanBatches() hanging
7c4a07e is described below
commit 7c4a07eeddf885189f7a6ad01ba9f98945bfa022
Author: David Li <[email protected]>
AuthorDate: Wed Apr 21 11:58:50 2021 +0900
ARROW-12487: [C++][Dataset] Fix ScanBatches() hanging
Errors weren't being handled in all paths.
Closes #10115 from lidavidm/arrow-12487
Authored-by: David Li <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
---
cpp/src/arrow/dataset/scanner.cc | 17 +++++-
cpp/src/arrow/dataset/scanner_test.cc | 103 ++++++++++++++++++++++++++++------
2 files changed, 100 insertions(+), 20 deletions(-)
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 43c0247..aa95c47 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -160,6 +160,19 @@ struct ScanBatchesState : public std::enable_shared_from_this<ScanBatchesState>
ready.notify_one();
}
+ template <typename T>
+ Result<T> PushError(Result<T>&& result, size_t task_index) {
+ if (!result.ok()) {
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+ task_drained[task_index] = true;
+ iteration_error = result.status();
+ }
+ ready.notify_one();
+ }
+ return std::move(result);
+ }
+
Status Finish(size_t task_index) {
{
std::lock_guard<std::mutex> lock(mutex);
@@ -190,9 +203,9 @@ struct ScanBatchesState : public std::enable_shared_from_this<ScanBatchesState>
lock.unlock();
task_group->Append([state, id, scan_task]() {
- ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute());
+ ARROW_ASSIGN_OR_RAISE(auto batch_it, state->PushError(scan_task->Execute(), id));
for (auto maybe_batch : batch_it) {
- ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch);
+ ARROW_ASSIGN_OR_RAISE(auto batch, state->PushError(std::move(maybe_batch), id));
state->Push(TaggedRecordBatch{std::move(batch), scan_task->fragment()}, id);
}
return state->Finish(id);
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 552102b..27fcef1 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -321,31 +321,98 @@ class FailingFragment : public InMemoryFragment {
}
};
+class FailingExecuteScanTask : public InMemoryScanTask {
+ public:
+ using InMemoryScanTask::InMemoryScanTask;
+
+ Result<RecordBatchIterator> Execute() override {
+ return Status::Invalid("Oh no, we failed!");
+ }
+};
+
+class FailingIterationScanTask : public InMemoryScanTask {
+ public:
+ using InMemoryScanTask::InMemoryScanTask;
+
+ Result<RecordBatchIterator> Execute() override {
+ int index = 0;
+ auto batches = record_batches_;
+ return MakeFunctionIterator(
+ [index, batches]() mutable -> Result<std::shared_ptr<RecordBatch>> {
+ if (index < 1) {
+ return batches[index++];
+ }
+ return Status::Invalid("Oh no, we failed!");
+ });
+ }
+};
+
+template <typename T>
+class FailingScanTaskFragment : public InMemoryFragment {
+ public:
+ using InMemoryFragment::InMemoryFragment;
+ Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+ auto self = shared_from_this();
+ ScanTaskVector scan_tasks{std::make_shared<T>(record_batches_, options, self)};
+ return MakeVectorIterator(std::move(scan_tasks));
+ }
+};
+
TEST_P(TestScanner, ScanBatchesFailure) {
SetSchema({field("i32", int32()), field("f64", float64())});
auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
RecordBatchVector batches = {batch, batch, batch, batch};
+ // Note these tests are only for SyncScanner at the moment
- ScannerBuilder builder(schema_, std::make_shared<FailingFragment>(batches), options_);
- ASSERT_OK(builder.UseThreads(GetParam().use_threads));
- ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
-
- ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
-
- int counter = 0;
- while (true) {
- // Make sure we get all batches that were yielded before the failing scan task
- auto maybe_batch = batch_it.Next();
- if (counter++ <= 16) {
- ASSERT_OK_AND_ASSIGN(auto scanned_batch, maybe_batch);
- AssertBatchesEqual(*batch, *scanned_batch.record_batch);
- ASSERT_NE(nullptr, scanned_batch.fragment);
- } else {
- EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"),
- maybe_batch);
- break;
+ // Case 1: failure when getting next scan task
+ {
+ ScannerBuilder builder(schema_, std::make_shared<FailingFragment>(batches), options_);
+ ASSERT_OK(builder.UseThreads(GetParam().use_threads));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+
+ int counter = 0;
+ while (true) {
+ // Make sure we get all batches that were yielded before the failing scan task
+ auto maybe_batch = batch_it.Next();
+ if (counter++ <= 16) {
+ ASSERT_OK_AND_ASSIGN(auto scanned_batch, maybe_batch);
+ AssertBatchesEqual(*batch, *scanned_batch.record_batch);
+ ASSERT_NE(nullptr, scanned_batch.fragment);
+ } else {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Oh no, we failed!"), maybe_batch);
+ break;
+ }
}
}
+
+ // Case 2: failure when calling ScanTask::Execute
+ {
+ ScannerBuilder builder(
+ schema_,
+ std::make_shared<FailingScanTaskFragment<FailingExecuteScanTask>>(batches),
+ options_);
+ ASSERT_OK(builder.UseThreads(GetParam().use_threads));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"),
+ batch_it.Next());
+ }
+
+ // Case 3: failure when calling RecordBatchIterator::Next
+ {
+ ScannerBuilder builder(
+ schema_,
+ std::make_shared<FailingScanTaskFragment<FailingIterationScanTask>>(batches),
+ options_);
+ ASSERT_OK(builder.UseThreads(GetParam().use_threads));
+ ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+ ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+ ASSERT_OK(batch_it.Next());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"),
+ batch_it.Next());
+ }
}
TEST_P(TestScanner, Head) {