This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 7c4a07e  ARROW-12487: [C++][Dataset] Fix ScanBatches() hanging
7c4a07e is described below

commit 7c4a07eeddf885189f7a6ad01ba9f98945bfa022
Author: David Li <[email protected]>
AuthorDate: Wed Apr 21 11:58:50 2021 +0900

    ARROW-12487: [C++][Dataset] Fix ScanBatches() hanging
    
    Errors weren't being handled in all paths.
    
    Closes #10115 from lidavidm/arrow-12487
    
    Authored-by: David Li <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 cpp/src/arrow/dataset/scanner.cc      |  17 +++++-
 cpp/src/arrow/dataset/scanner_test.cc | 103 ++++++++++++++++++++++++++++------
 2 files changed, 100 insertions(+), 20 deletions(-)

diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 43c0247..aa95c47 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -160,6 +160,19 @@ struct ScanBatchesState : public std::enable_shared_from_this<ScanBatchesState>
     ready.notify_one();
   }
 
+  template <typename T>
+  Result<T> PushError(Result<T>&& result, size_t task_index) {
+    if (!result.ok()) {
+      {
+        std::lock_guard<std::mutex> lock(mutex);
+        task_drained[task_index] = true;
+        iteration_error = result.status();
+      }
+      ready.notify_one();
+    }
+    return std::move(result);
+  }
+
   Status Finish(size_t task_index) {
     {
       std::lock_guard<std::mutex> lock(mutex);
@@ -190,9 +203,9 @@ struct ScanBatchesState : public std::enable_shared_from_this<ScanBatchesState>
 
     lock.unlock();
     task_group->Append([state, id, scan_task]() {
-      ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute());
+      ARROW_ASSIGN_OR_RAISE(auto batch_it, state->PushError(scan_task->Execute(), id));
       for (auto maybe_batch : batch_it) {
-        ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch);
+        ARROW_ASSIGN_OR_RAISE(auto batch, state->PushError(std::move(maybe_batch), id));
         state->Push(TaggedRecordBatch{std::move(batch), scan_task->fragment()}, id);
       }
       return state->Finish(id);
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 552102b..27fcef1 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -321,31 +321,98 @@ class FailingFragment : public InMemoryFragment {
   }
 };
 
+class FailingExecuteScanTask : public InMemoryScanTask {
+ public:
+  using InMemoryScanTask::InMemoryScanTask;
+
+  Result<RecordBatchIterator> Execute() override {
+    return Status::Invalid("Oh no, we failed!");
+  }
+};
+
+class FailingIterationScanTask : public InMemoryScanTask {
+ public:
+  using InMemoryScanTask::InMemoryScanTask;
+
+  Result<RecordBatchIterator> Execute() override {
+    int index = 0;
+    auto batches = record_batches_;
+    return MakeFunctionIterator(
+        [index, batches]() mutable -> Result<std::shared_ptr<RecordBatch>> {
+          if (index < 1) {
+            return batches[index++];
+          }
+          return Status::Invalid("Oh no, we failed!");
+        });
+  }
+};
+
+template <typename T>
+class FailingScanTaskFragment : public InMemoryFragment {
+ public:
+  using InMemoryFragment::InMemoryFragment;
+  Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+    auto self = shared_from_this();
+    ScanTaskVector scan_tasks{std::make_shared<T>(record_batches_, options, self)};
+    return MakeVectorIterator(std::move(scan_tasks));
+  }
+};
+
 TEST_P(TestScanner, ScanBatchesFailure) {
   SetSchema({field("i32", int32()), field("f64", float64())});
   auto batch = ConstantArrayGenerator::Zeroes(GetParam().items_per_batch, schema_);
   RecordBatchVector batches = {batch, batch, batch, batch};
+  // Note these tests are only for SyncScanner at the moment
 
-  ScannerBuilder builder(schema_, std::make_shared<FailingFragment>(batches), options_);
-  ASSERT_OK(builder.UseThreads(GetParam().use_threads));
-  ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
-
-  ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
-
-  int counter = 0;
-  while (true) {
-    // Make sure we get all batches that were yielded before the failing scan task
-    auto maybe_batch = batch_it.Next();
-    if (counter++ <= 16) {
-      ASSERT_OK_AND_ASSIGN(auto scanned_batch, maybe_batch);
-      AssertBatchesEqual(*batch, *scanned_batch.record_batch);
-      ASSERT_NE(nullptr, scanned_batch.fragment);
-    } else {
-      EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"),
-                                      maybe_batch);
-      break;
+  // Case 1: failure when getting next scan task
+  {
+    ScannerBuilder builder(schema_, std::make_shared<FailingFragment>(batches), options_);
+    ASSERT_OK(builder.UseThreads(GetParam().use_threads));
+    ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+    ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+
+    int counter = 0;
+    while (true) {
+      // Make sure we get all batches that were yielded before the failing scan task
+      auto maybe_batch = batch_it.Next();
+      if (counter++ <= 16) {
+        ASSERT_OK_AND_ASSIGN(auto scanned_batch, maybe_batch);
+        AssertBatchesEqual(*batch, *scanned_batch.record_batch);
+        ASSERT_NE(nullptr, scanned_batch.fragment);
+      } else {
+        EXPECT_RAISES_WITH_MESSAGE_THAT(
+            Invalid, ::testing::HasSubstr("Oh no, we failed!"), maybe_batch);
+        break;
+      }
     }
   }
+
+  // Case 2: failure when calling ScanTask::Execute
+  {
+    ScannerBuilder builder(
+        schema_,
+        std::make_shared<FailingScanTaskFragment<FailingExecuteScanTask>>(batches),
+        options_);
+    ASSERT_OK(builder.UseThreads(GetParam().use_threads));
+    ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+    ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"),
+                                    batch_it.Next());
+  }
+
+  // Case 3: failure when calling RecordBatchIterator::Next
+  {
+    ScannerBuilder builder(
+        schema_,
+        std::make_shared<FailingScanTaskFragment<FailingIterationScanTask>>(batches),
+        options_);
+    ASSERT_OK(builder.UseThreads(GetParam().use_threads));
+    ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
+    ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
+    ASSERT_OK(batch_it.Next());
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("Oh no, we failed!"),
+                                    batch_it.Next());
+  }
 }
 
 TEST_P(TestScanner, Head) {

Reply via email to