This is an automated email from the ASF dual-hosted git repository. raulcd pushed a commit to branch maint-12.0.0 in repository https://gitbox.apache.org/repos/asf/arrow.git
commit 0bff85ca9efa1c07ef6d9ada2ed6b9d7704507e4 Author: Weston Pace <[email protected]> AuthorDate: Thu Apr 13 07:50:32 2023 -0700 GH-34539: [C++] Fix throttled scheduler to avoid stack overflow in dataset writer (#35075) ### Rationale for this change Fixes a bug in the throttled scheduler. ### What changes are included in this PR? The throttled scheduler will no longer recurse in the ContinueTasks loop if the continued task was immediately finished. ### Are these changes tested? Yes, I added a new stress test that exposed the stack overflow very reliably on a standard Linux system. ### Are there any user-facing changes? No. * Closes: #34539 Authored-by: Weston Pace <[email protected]> Signed-off-by: Weston Pace <[email protected]> --- cpp/src/arrow/util/async_util.cc | 26 +++++++++++++++++++------- cpp/src/arrow/util/async_util_test.cc | 24 ++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/util/async_util.cc b/cpp/src/arrow/util/async_util.cc index 0a59a462c9..55627eb43b 100644 --- a/cpp/src/arrow/util/async_util.cc +++ b/cpp/src/arrow/util/async_util.cc @@ -322,7 +322,7 @@ class ThrottledAsyncTaskSchedulerImpl return true; } else { lk.unlock(); - return SubmitTask(std::move(task), latched_cost); + return SubmitTask(std::move(task), latched_cost, /*in_continue=*/false); } } @@ -331,18 +331,30 @@ class ThrottledAsyncTaskSchedulerImpl const util::tracing::Span& span() const override { return target_->span(); } private: - bool SubmitTask(std::unique_ptr<Task> task, int latched_cost) { + bool SubmitTask(std::unique_ptr<Task> task, int latched_cost, bool in_continue) { // Wrap the task with a wrapper that runs it and then checks to see if there are any // queued tasks std::string_view name = task->name(); return target_->AddSimpleTask( - [latched_cost, inner_task = std::move(task), + [latched_cost, in_continue, inner_task = std::move(task), self = shared_from_this()]() mutable -> Result<Future<>> { ARROW_ASSIGN_OR_RAISE(Future<> inner_fut, (*inner_task)()); - return inner_fut.Then([latched_cost, self = std::move(self)] { + if (!inner_fut.TryAddCallback([&] { + return [latched_cost, self = std::move(self)](const Status& st) -> void { + if (st.ok()) { + self->throttle_->Release(latched_cost); + self->ContinueTasks(); + } + }; + })) { + // If the task is already finished then don't run ContinueTasks + // if we are already running it so we can avoid stack overflow self->throttle_->Release(latched_cost); - self->ContinueTasks(); - }); + if (!in_continue) { + self->ContinueTasks(); + } + } + return inner_fut; }, name); } @@ -371,7 +383,7 @@ class ThrottledAsyncTaskSchedulerImpl } else { std::unique_ptr<Task> next_task = queue_->Pop(); lk.unlock(); - if (!SubmitTask(std::move(next_task), next_cost)) { + if (!SubmitTask(std::move(next_task), next_cost, /*in_continue=*/true)) { return; } lk.lock(); diff --git a/cpp/src/arrow/util/async_util_test.cc b/cpp/src/arrow/util/async_util_test.cc index 119ca7aa42..7734b84c9e 100644 --- a/cpp/src/arrow/util/async_util_test.cc +++ b/cpp/src/arrow/util/async_util_test.cc @@ -595,6 +595,30 @@ TEST(AsyncTaskScheduler, ScanningStress) { ASSERT_EQ(kExpectedBatchesScanned, batches_scanned.load()); } } + +TEST(AsyncTaskScheduler, ThrottleStress) { + // Queue up a bunch of throttled fast tasks. It shouldn't cause stack overflow + constexpr int kNumTasks = 1024 * 10; + int num_tasks_run = 0; + Future<> slow_task = Future<>::Make(); + Future<> finished = AsyncTaskScheduler::Make([&](AsyncTaskScheduler* scheduler) { + std::shared_ptr<ThrottledAsyncTaskScheduler> throttled = + ThrottledAsyncTaskScheduler::Make(scheduler, 1); + EXPECT_TRUE(throttled->AddSimpleTask([slow_task] { return slow_task; }, kDummyName)); + for (int task_idx = 0; task_idx < kNumTasks; task_idx++) { + throttled->AddSimpleTask( + [&] { + num_tasks_run++; + return Future<>::MakeFinished(); + }, + kDummyName); + } + return Status::OK(); + }); + slow_task.MarkFinished(); + ASSERT_FINISHES_OK(finished); + ASSERT_EQ(kNumTasks, num_tasks_run); +} #endif class TaskWithPriority : public AsyncTaskScheduler::Task {
