This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a2ef144ea3 Refactor RewriteTensorize to prevent concurrent map updates (#11596)
a2ef144ea3 is described below

commit a2ef144ea3aa8ae763c59cc596e73d6a89b3f046
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Jun 7 00:57:59 2022 -0700

    Refactor RewriteTensorize to prevent concurrent map updates (#11596)
---
 src/meta_schedule/postproc/rewrite_tensorize.cc | 30 ++++++++++++++-----------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc
index 1ad394e49c..3df9075972 100644
--- a/src/meta_schedule/postproc/rewrite_tensorize.cc
+++ b/src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -28,10 +28,10 @@ namespace meta_schedule {
 using tir::BlockRV;
 using tir::LoopRV;
 
-void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
-                        const tir::PrimFuncNode* func, bool 
vectorize_init_loop) {
-  std::vector<std::pair<std::string, std::function<void(tir::BlockRV)>>> jobs;
-
+void CollectTensorizationJobs(
+    const tir::Schedule& sch, const String& func_name, const 
tir::PrimFuncNode* func,
+    bool vectorize_init_loop,
+    std::vector<std::tuple<String, String, 
std::function<void(tir::BlockRV)>>>* jobs) {
   tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
     if (const auto* block = obj.as<tir::BlockNode>()) {
       tir::StmtSRef block_sref = sch->GetSRef(block);
@@ -39,7 +39,7 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
               tir::GetAnn<String>(block_sref, 
tir::attr::meta_schedule_auto_tensorize)) {
         std::string block_name = 
block_sref->StmtAs<tir::BlockNode>()->name_hint;
         if (block_name.find("init") == std::string::npos) {
-          jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) 
{
+          jobs->emplace_back(block_name, func_name, [sch, 
intrin_name](tir::BlockRV block) {
             try {
               sch->Tensorize(block, intrin_name.value());
             } catch (const std::exception& e) {
@@ -47,7 +47,7 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
             }
           });
         } else if (vectorize_init_loop) {
-          jobs.emplace_back(block_name, [sch](tir::BlockRV block) {
+          jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
             Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
             ICHECK(child_blocks.size() == 1);
             Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
@@ -58,12 +58,6 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
       }
     }
   });
-
-  for (auto kv : jobs) {
-    tir::BlockRV block = sch->GetBlock(kv.first, func_name);
-    sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
-    kv.second(block);
-  }
 }
 
 class RewriteTensorizeNode : public PostprocNode {
@@ -81,13 +75,23 @@ class RewriteTensorizeNode : public PostprocNode {
 };
 
 bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
+  // The rewriting jobs, 3-tuple (block_name, func_name, job_func)
+  std::vector<std::tuple<String, String, std::function<void(tir::BlockRV)>>> 
jobs;
   for (const auto& kv : sch->mod()->functions) {
     GlobalVar g_var = kv.first;
     BaseFunc base_func = kv.second;
     if (const tir::PrimFuncNode* prim_func = 
base_func.as<tir::PrimFuncNode>()) {
-      ApplyTensorization(sch, g_var->name_hint, prim_func, 
vectorize_init_loop);
+      CollectTensorizationJobs(sch, g_var->name_hint, prim_func, 
vectorize_init_loop, &jobs);
     }
   }
+  for (const auto& job : jobs) {
+    const String& block_name = std::get<0>(job);
+    const String& func_name = std::get<1>(job);
+    const auto& job_func = std::get<2>(job);
+    BlockRV block = sch->GetBlock(block_name, func_name);
+    sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
+    job_func(block);
+  }
   return true;
 }
 

Reply via email to