This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 748882aae7 [Runtime] Parallel-for with threading backend (#16133)
748882aae7 is described below

commit 748882aae7be1435f042e22b0fc67cb236705b6c
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Nov 16 05:20:44 2023 -0500

    [Runtime] Parallel-for with threading backend (#16133)
    
    This PR introduces a runtime parallel-for helper function
    in C++ that is built on the threading backend in TVM.
    
    Right now the existing
    [parallel-for](https://github.com/apache/tvm/blob/bd67d2e5ebde1aec18bcfa74c087516579bda1ae/include/tvm/support/parallel_for.h#L48-L68)
    in TVM is not thread-persistent,
    so we cannot get persistent thread-local storage (TLS) for each thread.
    
    The introduced parallel-for-with-threading-backend function
    leverages the threading backend in TVM and persists threads.
---
 include/tvm/runtime/threading_backend.h | 70 +++++++++++++++++++++++++++++++++
 tests/cpp/threading_backend_test.cc     |  9 +++++
 2 files changed, 79 insertions(+)

diff --git a/include/tvm/runtime/threading_backend.h 
b/include/tvm/runtime/threading_backend.h
index 77d6730c09..3122b000e0 100644
--- a/include/tvm/runtime/threading_backend.h
+++ b/include/tvm/runtime/threading_backend.h
@@ -24,6 +24,9 @@
 #ifndef TVM_RUNTIME_THREADING_BACKEND_H_
 #define TVM_RUNTIME_THREADING_BACKEND_H_
 
+#include <tvm/runtime/c_backend_api.h>
+
+#include <algorithm>
 #include <functional>
 #include <memory>
 #include <vector>
@@ -147,6 +150,73 @@ TVM_DLL void 
Configure(tvm::runtime::threading::ThreadGroup::AffinityMode mode,
 int32_t NumThreads();
 
 }  // namespace threading
+
+/*!
+ * \brief Execute the given lambda function in parallel using the
+ * threading backend in TVM.
+ * \tparam T The type of the lambda: "void (int i)"; the index is passed as an int64_t.
+ * \param flambda The lambda to be executed in parallel, once per index in
+ * [begin, end). It should have the signature "void (int i)".
+ * \param begin The start index of this parallel loop (inclusive).
+ * \param end The end index of this parallel loop (exclusive).
+ * \example
+ *
+ * The for loop
+ *   for (int i = 0; i < 10; i++) {
+ *     a[i] = i;
+ *   }
+ * should work the same as (provided the iterations are independent):
+ *   parallel_for_with_threading_backend([&a](int i) {
+ *     a[i] = i;
+ *   }, 0, 10);
+ */
+template <typename T>
+inline void parallel_for_with_threading_backend(T flambda, int64_t begin, 
int64_t end);
+
+namespace detail {
+
+// The detailed implementation of `parallel_for_with_threading_backend`.
+// These helpers are templates, so their definitions must stay in this
+// header and cannot be placed in .cc files.
+
+template <typename T>  // T: callable invoked as (int task_id, int num_task)
+struct ParallelForWithThreadingBackendLambdaInvoker {  // adapts a callable of type T to the (task_id, penv, cdata) callback form
+  static int TVMParallelLambdaInvoke(int task_id, TVMParallelGroupEnv* penv, 
void* cdata) {
+    int num_task = penv->num_task;  // total task count read from the launch environment
+    // cdata carries the address of a callable of type T; cast it back.
+    T* lambda_ptr = static_cast<T*>(cdata);
+    // Invoke the callable with this task's id and the total number of tasks.
+    (*lambda_ptr)(task_id, num_task);
+    return 0;  // always 0; presumably the backend's success code — TODO confirm
+  }
+};
+
+template <typename T>
+inline void parallel_launch_with_threading_backend(T flambda) {
+  // Pass the address of the by-value copy; assumes the launch returns only after all tasks finish — TODO confirm.
+  void* cdata = &flambda;
+  
TVMBackendParallelLaunch(ParallelForWithThreadingBackendLambdaInvoker<T>::TVMParallelLambdaInvoke,
+                           cdata, /*num_task=*/0);  // num_task=0 presumably lets the backend pick the task count — confirm; NOTE(review): the call's result, if any, is not checked
+}
+
+}  // namespace detail
+
+template <typename T>
+inline void parallel_for_with_threading_backend(T flambda, int64_t begin, 
int64_t end) {
+  auto flaunch = [begin, end, flambda](int task_id, int num_task) {
+    // Static division: task `task_id` handles one contiguous chunk of [begin, end).
+    int64_t total_len = end - begin;
+    int64_t step = (total_len + num_task - 1) / num_task;  // chunk size (ceiling division for a non-empty range)
+    int64_t local_begin = std::min(begin + step * task_id, end);  // clamped, so surplus tasks get an empty range
+    int64_t local_end = std::min(local_begin + step, end);  // the last non-empty chunk may be shorter than step
+    for (int64_t i = local_begin; i < local_end; ++i) {
+      flambda(i);
+    }
+  };
+  // Hand the per-task closure to the threading backend launcher.
+  detail::parallel_launch_with_threading_backend(flaunch);
+}
+
 }  // namespace runtime
 }  // namespace tvm
 
diff --git a/tests/cpp/threading_backend_test.cc 
b/tests/cpp/threading_backend_test.cc
index 5adf1f9ae3..b156eec8ab 100644
--- a/tests/cpp/threading_backend_test.cc
+++ b/tests/cpp/threading_backend_test.cc
@@ -185,3 +185,12 @@ TEST(ThreadingBackend, TVMBackendAffinityConfigure) {
     t->join();
   }
 }
+
+TEST(ThreadingBackend, TVMBackendParallelForWithThreadingBackend) {  // each index in [0, n) must end up written with its own value
+  int n = 100;
+  std::vector<int> vec(/*size=*/n, /*value=*/0);  // NOTE(review): zero-filled, so a skipped index 0 would go unnoticed
+  tvm::runtime::parallel_for_with_threading_backend([&vec](int i) { vec[i] = 
i; }, 0, n);
+  for (int i = 0; i < n; ++i) {  // verify every slot after the parallel loop has returned
+    EXPECT_EQ(vec[i], i);
+  }
+}

Reply via email to