This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 748882aae7 [Runtime] Parallel-for with threading backend (#16133)
748882aae7 is described below
commit 748882aae7be1435f042e22b0fc67cb236705b6c
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Nov 16 05:20:44 2023 -0500
[Runtime] Parallel-for with threading backend (#16133)
This PR introduces the runtime parallel-for helper function
in C++ with the threading backend in TVM.
The existing
[parallel-for](https://github.com/apache/tvm/blob/bd67d2e5ebde1aec18bcfa74c087516579bda1ae/include/tvm/support/parallel_for.h#L48-L68)
in TVM is not thread-persistent,
so we cannot rely on persistent thread-local storage (TLS) for each thread.
The new parallel-for-with-threading-backend function
leverages TVM's threading backend, whose threads persist across calls.
---
include/tvm/runtime/threading_backend.h | 70 +++++++++++++++++++++++++++++++++
tests/cpp/threading_backend_test.cc | 9 +++++
2 files changed, 79 insertions(+)
diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h
index 77d6730c09..3122b000e0 100644
--- a/include/tvm/runtime/threading_backend.h
+++ b/include/tvm/runtime/threading_backend.h
@@ -24,6 +24,9 @@
#ifndef TVM_RUNTIME_THREADING_BACKEND_H_
#define TVM_RUNTIME_THREADING_BACKEND_H_
#include <tvm/runtime/c_backend_api.h>

#include <algorithm>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <utility>
#include <vector>
@@ -147,6 +150,73 @@ TVM_DLL void Configure(tvm::runtime::threading::ThreadGroup::AffinityMode mode,
int32_t NumThreads();
} // namespace threading
+
/*!
 * \brief Execute the given lambda function in parallel with
 *  threading backend in TVM.
 * \tparam T The type of the lambda: "void (int64_t i)".
 * \param flambda The lambda to be executed in parallel. It is invoked once for
 *  every index in [begin, end), and the index is passed as an int64_t. A lambda
 *  taking `int` still compiles, but indices outside the int range are narrowed.
 *  Invocations for different indices may run concurrently on different threads,
 *  so flambda must be safe to call concurrently and callable as a const object.
 *  If begin >= end, flambda is never called.
 * \param begin The start index of this parallel loop (inclusive).
 * \param end The end index of this parallel loop (exclusive).
 * \example
 *
 * The for loop
 *   for (int64_t i = 0; i < 10; i++) {
 *     a[i] = i;
 *   }
 * should work the same as:
 *   parallel_for_with_threading_backend([&a](int64_t i) {
 *     a[i] = i;
 *   }, 0, 10);
 */
template <typename T>
inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end);
+
+namespace detail {
+
+// The detailed implementation of `parallel_for_with_threading_backend`.
+// To avoid template expansion, the implementation cannot be placed
+// in .cc files.
+
+template <typename T>
+struct ParallelForWithThreadingBackendLambdaInvoker {
+ static int TVMParallelLambdaInvoke(int task_id, TVMParallelGroupEnv* penv,
void* cdata) {
+ int num_task = penv->num_task;
+ // Convert void* back to lambda type.
+ T* lambda_ptr = static_cast<T*>(cdata);
+ // Invoke the lambda with the task id (thread id).
+ (*lambda_ptr)(task_id, num_task);
+ return 0;
+ }
+};
+
+template <typename T>
+inline void parallel_launch_with_threading_backend(T flambda) {
+ // Launch the lambda by passing its address.
+ void* cdata = &flambda;
+
TVMBackendParallelLaunch(ParallelForWithThreadingBackendLambdaInvoker<T>::TVMParallelLambdaInvoke,
+ cdata, /*num_task=*/0);
+}
+
+} // namespace detail
+
+template <typename T>
+inline void parallel_for_with_threading_backend(T flambda, int64_t begin,
int64_t end) {
+ auto flaunch = [begin, end, flambda](int task_id, int num_task) {
+ // For each thread, do static division and call into flambda.
+ int64_t total_len = end - begin;
+ int64_t step = (total_len + num_task - 1) / num_task;
+ int64_t local_begin = std::min(begin + step * task_id, end);
+ int64_t local_end = std::min(local_begin + step, end);
+ for (int64_t i = local_begin; i < local_end; ++i) {
+ flambda(i);
+ }
+ };
+ // Launch with all threads.
+ detail::parallel_launch_with_threading_backend(flaunch);
+}
+
} // namespace runtime
} // namespace tvm
diff --git a/tests/cpp/threading_backend_test.cc b/tests/cpp/threading_backend_test.cc
index 5adf1f9ae3..b156eec8ab 100644
--- a/tests/cpp/threading_backend_test.cc
+++ b/tests/cpp/threading_backend_test.cc
@@ -185,3 +185,12 @@ TEST(ThreadingBackend, TVMBackendAffinityConfigure) {
t->join();
}
}
+
+TEST(ThreadingBackend, TVMBackendParallelForWithThreadingBackend) {
+ int n = 100;
+ std::vector<int> vec(/*size=*/n, /*value=*/0);
+ tvm::runtime::parallel_for_with_threading_backend([&vec](int i) { vec[i] =
i; }, 0, n);
+ for (int i = 0; i < n; ++i) {
+ EXPECT_EQ(vec[i], i);
+ }
+}