This is an automated email from the ASF dual-hosted git repository.
zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new e9ec38396e GH-45266: [C++][Acero] Fix the running tasks count of
Scheduler when get error tasks in multi-threads (#45268)
e9ec38396e is described below
commit e9ec38396ed66cd34fea2884fc7ebbe0281a05ef
Author: Yunpei Zhou <[email protected]>
AuthorDate: Wed Feb 12 12:37:23 2025 +0800
GH-45266: [C++][Acero] Fix the running tasks count of Scheduler when get
error tasks in multi-threads (#45268)
### Rationale for this change
When the TaskGroup is canceled, the scheduler moves the count of not-yet-started
tasks into the finished count so that they are skipped (in
`TaskSchedulerImpl::Abort`). However, this operation happens concurrently across
multiple threads: at the same moment, some tasks may already be running and may
encounter an error, in which case they return a bad status.
Because those tasks are still counted as running by the Scheduler, returning the
bad status immediately (via `RETURN_NOT_OK`) skips updating the running-task count.
### What changes are included in this PR?
Every task now updates the running-task count before returning, regardless of
which status it returns.
### Are these changes tested?
Yes. Unit tests were added in `task_util_test.cc` covering abort-on-error in both serial and parallel execution.
### Are there any user-facing changes?
No. But I am surprised this has not been reported by anyone before.
* GitHub Issue: #45266
Lead-authored-by: zhouyunpei <[email protected]>
Co-authored-by: Rossi Sun <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
---
cpp/src/arrow/acero/task_util.cc | 87 +++++++++++++++++++--------------
cpp/src/arrow/acero/task_util_test.cc | 92 +++++++++++++++++++++++++++++++++++
2 files changed, 143 insertions(+), 36 deletions(-)
diff --git a/cpp/src/arrow/acero/task_util.cc b/cpp/src/arrow/acero/task_util.cc
index 85378eaeeb..082ec99946 100644
--- a/cpp/src/arrow/acero/task_util.cc
+++ b/cpp/src/arrow/acero/task_util.cc
@@ -91,11 +91,11 @@ class TaskSchedulerImpl : public TaskScheduler {
AbortContinuationImpl abort_cont_impl_;
std::vector<TaskGroup> task_groups_;
- bool aborted_;
bool register_finished_;
std::mutex mutex_; // Mutex protecting task_groups_ (state_ and
num_tasks_present_
- // fields), aborted_ flag and register_finished_ flag
+ // fields) and register_finished_ flag
+ AtomicWithPadding<bool> aborted_;
AtomicWithPadding<int> num_tasks_to_schedule_;
// If a task group adds tasks it's possible for a thread inside
// ScheduleMore to miss this fact. This serves as a flag to
@@ -105,10 +105,8 @@ class TaskSchedulerImpl : public TaskScheduler {
};
TaskSchedulerImpl::TaskSchedulerImpl()
- : use_sync_execution_(false),
- num_concurrent_tasks_(0),
- aborted_(false),
- register_finished_(false) {
+ : use_sync_execution_(false), num_concurrent_tasks_(0),
register_finished_(false) {
+ aborted_.value.store(false);
num_tasks_to_schedule_.value.store(0);
tasks_added_recently_.value.store(false);
}
@@ -131,13 +129,11 @@ Status TaskSchedulerImpl::StartTaskGroup(size_t
thread_id, int group_id,
ARROW_DCHECK(group_id >= 0 && group_id <
static_cast<int>(task_groups_.size()));
TaskGroup& task_group = task_groups_[group_id];
- bool aborted = false;
+ bool aborted = aborted_.value.load();
bool all_tasks_finished = false;
{
std::lock_guard<std::mutex> lock(mutex_);
- aborted = aborted_;
-
if (task_group.state_ == TaskGroupState::NOT_READY) {
task_group.num_tasks_present_ = total_num_tasks;
if (total_num_tasks == 0) {
@@ -212,7 +208,7 @@ std::vector<std::pair<int, int64_t>>
TaskSchedulerImpl::PickTasks(int num_tasks,
Status TaskSchedulerImpl::ExecuteTask(size_t thread_id, int group_id, int64_t
task_id,
bool* task_group_finished) {
- if (!aborted_) {
+ if (!aborted_.value.load()) {
RETURN_NOT_OK(task_groups_[group_id].task_impl_(thread_id, task_id));
}
*task_group_finished = PostExecuteTask(thread_id, group_id);
@@ -228,11 +224,10 @@ bool TaskSchedulerImpl::PostExecuteTask(size_t thread_id,
int group_id) {
Status TaskSchedulerImpl::OnTaskGroupFinished(size_t thread_id, int group_id,
bool* all_task_groups_finished) {
- bool aborted = false;
+ bool aborted = aborted_.value.load();
{
std::lock_guard<std::mutex> lock(mutex_);
- aborted = aborted_;
TaskGroup& task_group = task_groups_[group_id];
task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
*all_task_groups_finished = true;
@@ -260,7 +255,7 @@ Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int
num_tasks_to_execute
int last_id = 0;
for (;;) {
- if (aborted_) {
+ if (aborted_.value.load()) {
return Status::Cancelled("Scheduler cancelled");
}
@@ -278,8 +273,8 @@ Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int
num_tasks_to_execute
bool task_group_finished = false;
Status status = ExecuteTask(thread_id, group_id, task_id,
&task_group_finished);
if (!status.ok()) {
- // Mark the remaining picked tasks as finished
- for (size_t j = i + 1; j < tasks.size(); ++j) {
+ // Mark the current and remaining picked tasks as finished
+ for (size_t j = i; j < tasks.size(); ++j) {
if (PostExecuteTask(thread_id, tasks[j].first)) {
bool all_task_groups_finished = false;
RETURN_NOT_OK(
@@ -328,7 +323,7 @@ Status TaskSchedulerImpl::StartScheduling(size_t thread_id,
ScheduleImpl schedul
}
Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int
num_tasks_finished) {
- if (aborted_) {
+ if (aborted_.value.load()) {
return Status::Cancelled("Scheduler cancelled");
}
@@ -369,17 +364,25 @@ Status TaskSchedulerImpl::ScheduleMore(size_t thread_id,
int num_tasks_finished)
int group_id = tasks[i].first;
int64_t task_id = tasks[i].second;
RETURN_NOT_OK(schedule_impl_([this, group_id, task_id](size_t thread_id)
-> Status {
- RETURN_NOT_OK(ScheduleMore(thread_id, 1));
-
bool task_group_finished = false;
- RETURN_NOT_OK(ExecuteTask(thread_id, group_id, task_id,
&task_group_finished));
+ // PostExecuteTask must be called later if any error ocurres during task
execution
+ // (including ScheduleMore), so we preserve the status.
+ auto status = [&]() {
+ RETURN_NOT_OK(ScheduleMore(thread_id, 1));
+ return ExecuteTask(thread_id, group_id, task_id, &task_group_finished);
+ }();
+
+ if (!status.ok()) {
+ task_group_finished = PostExecuteTask(thread_id, group_id);
+ }
if (task_group_finished) {
bool all_task_groups_finished = false;
- return OnTaskGroupFinished(thread_id, group_id,
&all_task_groups_finished);
+ RETURN_NOT_OK(
+ OnTaskGroupFinished(thread_id, group_id,
&all_task_groups_finished));
}
- return Status::OK();
+ return status;
}));
}
@@ -388,31 +391,43 @@ Status TaskSchedulerImpl::ScheduleMore(size_t thread_id,
int num_tasks_finished)
void TaskSchedulerImpl::Abort(AbortContinuationImpl impl) {
bool all_finished = true;
+ DCHECK_EQ(aborted_.value.load(), false);
+ aborted_.value.store(true);
{
std::lock_guard<std::mutex> lock(mutex_);
- aborted_ = true;
abort_cont_impl_ = std::move(impl);
if (register_finished_) {
for (size_t i = 0; i < task_groups_.size(); ++i) {
TaskGroup& task_group = task_groups_[i];
- if (task_group.state_ == TaskGroupState::NOT_READY) {
- task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
- } else if (task_group.state_ == TaskGroupState::READY) {
- int64_t expected = task_group.num_tasks_started_.value.load();
- for (;;) {
- if (task_group.num_tasks_started_.value.compare_exchange_strong(
- expected, task_group.num_tasks_present_)) {
- break;
+ switch (task_group.state_) {
+ case TaskGroupState::NOT_READY: {
+ task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
+ break;
+ }
+ case TaskGroupState::READY: {
+ int64_t expected = task_group.num_tasks_started_.value.load();
+ for (;;) {
+ if (task_group.num_tasks_started_.value.compare_exchange_strong(
+ expected, task_group.num_tasks_present_)) {
+ break;
+ }
}
+ int64_t before_add =
task_group.num_tasks_finished_.value.fetch_add(
+ task_group.num_tasks_present_ - expected);
+ if (before_add >= expected) {
+ task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
+ } else {
+ all_finished = false;
+ task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
+ }
+ break;
}
- int64_t before_add = task_group.num_tasks_finished_.value.fetch_add(
- task_group.num_tasks_present_ - expected);
- if (before_add >= expected) {
- task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
- } else {
+ case TaskGroupState::ALL_TASKS_STARTED: {
all_finished = false;
- task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
+ break;
}
+ default:
+ break;
}
}
}
diff --git a/cpp/src/arrow/acero/task_util_test.cc
b/cpp/src/arrow/acero/task_util_test.cc
index d5196ad4e0..30f80012e5 100644
--- a/cpp/src/arrow/acero/task_util_test.cc
+++ b/cpp/src/arrow/acero/task_util_test.cc
@@ -231,5 +231,97 @@ TEST(TaskScheduler, StressTwo) {
}
}
+TEST(TaskScheduler, AbortContOnTaskErrorSerial) {
+ constexpr int kNumTasks = 16;
+
+ auto scheduler = TaskScheduler::Make();
+ auto task = [&](std::size_t, int64_t task_id) {
+ if (task_id == kNumTasks / 2) {
+ return Status::Invalid("Task failed");
+ }
+ return Status::OK();
+ };
+
+ int task_group =
+ scheduler->RegisterTaskGroup(task, [](std::size_t) { return
Status::OK(); });
+ scheduler->RegisterEnd();
+
+ ASSERT_OK(scheduler->StartScheduling(
+ /*thread_id=*/0,
+ /*schedule_impl=*/
+ [](TaskScheduler::TaskGroupContinuationImpl) { return Status::OK(); },
+ /*num_concurrent_tasks=*/1, /*use_sync_execution=*/true));
+ ASSERT_RAISES_WITH_MESSAGE(
+ Invalid, "Invalid: Task failed",
+ scheduler->StartTaskGroup(/*thread_id=*/0, task_group, kNumTasks));
+
+ int num_abort_cont_calls = 0;
+ auto abort_cont = [&]() { ++num_abort_cont_calls; };
+
+ scheduler->Abort(abort_cont);
+
+ ASSERT_EQ(num_abort_cont_calls, 1);
+}
+
+TEST(TaskScheduler, AbortContOnTaskErrorParallel) {
+#ifndef ARROW_ENABLE_THREADING
+ GTEST_SKIP() << "Test requires threading support";
+#endif
+ constexpr int kNumThreads = 16;
+
+ ThreadIndexer thread_indexer;
+ int num_threads = std::min(static_cast<int>(thread_indexer.Capacity()),
kNumThreads);
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<ThreadPool> thread_pool,
+ MakePrimedThreadPool(num_threads));
+ TaskScheduler::ScheduleImpl schedule =
+ [&](TaskScheduler::TaskGroupContinuationImpl task) {
+ return thread_pool->Spawn([&, task] {
+ std::size_t thread_id = thread_indexer();
+ auto status = task(thread_id);
+ ASSERT_TRUE(status.ok() || status.IsInvalid() ||
status.IsCancelled())
+ << status;
+ });
+ };
+
+ for (int num_tasks :
+ {2, num_threads - 1, num_threads, num_threads + 1, 2 * num_threads}) {
+ ARROW_SCOPED_TRACE("num_tasks = ", num_tasks);
+ for (int num_concurrent_tasks :
+ {1, num_tasks - 1, num_tasks, num_tasks + 1, 2 * num_tasks}) {
+ ARROW_SCOPED_TRACE("num_concurrent_tasks = ", num_concurrent_tasks);
+ for (int aborting_task_id = 0; aborting_task_id < num_tasks;
++aborting_task_id) {
+ ARROW_SCOPED_TRACE("aborting_task_id = ", aborting_task_id);
+ auto scheduler = TaskScheduler::Make();
+
+ int num_abort_cont_calls = 0;
+ auto abort_cont = [&]() { ++num_abort_cont_calls; };
+
+ auto task = [&](std::size_t, int64_t task_id) {
+ if (task_id == aborting_task_id) {
+ scheduler->Abort(abort_cont);
+ }
+ if (task_id % 2 == 0) {
+ return Status::Invalid("Task failed");
+ }
+ return Status::OK();
+ };
+
+ int task_group =
+ scheduler->RegisterTaskGroup(task, [](std::size_t) { return
Status::OK(); });
+ scheduler->RegisterEnd();
+
+ ASSERT_OK(scheduler->StartScheduling(/*thread_id=*/0, schedule,
+ num_concurrent_tasks,
+ /*use_sync_execution=*/false));
+ ASSERT_OK(scheduler->StartTaskGroup(/*thread_id=*/0, task_group,
num_tasks));
+
+ thread_pool->WaitForIdle();
+
+ ASSERT_EQ(num_abort_cont_calls, 1);
+ }
+ }
+ }
+}
+
} // namespace acero
} // namespace arrow