This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a2ef144ea3 Refactor RewriteTensorize to prevent concurrent map updates (#11596)
a2ef144ea3 is described below
commit a2ef144ea3aa8ae763c59cc596e73d6a89b3f046
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Jun 7 00:57:59 2022 -0700
Refactor RewriteTensorize to prevent concurrent map updates (#11596)
---
src/meta_schedule/postproc/rewrite_tensorize.cc | 30 ++++++++++++++-----------
1 file changed, 17 insertions(+), 13 deletions(-)
diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc
index 1ad394e49c..3df9075972 100644
--- a/src/meta_schedule/postproc/rewrite_tensorize.cc
+++ b/src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -28,10 +28,10 @@ namespace meta_schedule {
using tir::BlockRV;
using tir::LoopRV;
-void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
- const tir::PrimFuncNode* func, bool
vectorize_init_loop) {
- std::vector<std::pair<std::string, std::function<void(tir::BlockRV)>>> jobs;
-
+void CollectTensorizationJobs(
+ const tir::Schedule& sch, const String& func_name, const
tir::PrimFuncNode* func,
+ bool vectorize_init_loop,
+ std::vector<std::tuple<String, String,
std::function<void(tir::BlockRV)>>>* jobs) {
tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
if (const auto* block = obj.as<tir::BlockNode>()) {
tir::StmtSRef block_sref = sch->GetSRef(block);
@@ -39,7 +39,7 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
tir::GetAnn<String>(block_sref,
tir::attr::meta_schedule_auto_tensorize)) {
std::string block_name =
block_sref->StmtAs<tir::BlockNode>()->name_hint;
if (block_name.find("init") == std::string::npos) {
- jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block)
{
+ jobs->emplace_back(block_name, func_name, [sch,
intrin_name](tir::BlockRV block) {
try {
sch->Tensorize(block, intrin_name.value());
} catch (const std::exception& e) {
@@ -47,7 +47,7 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
}
});
} else if (vectorize_init_loop) {
- jobs.emplace_back(block_name, [sch](tir::BlockRV block) {
+ jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
ICHECK(child_blocks.size() == 1);
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
@@ -58,12 +58,6 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
}
}
});
-
- for (auto kv : jobs) {
- tir::BlockRV block = sch->GetBlock(kv.first, func_name);
- sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
- kv.second(block);
- }
}
class RewriteTensorizeNode : public PostprocNode {
@@ -81,13 +75,23 @@ class RewriteTensorizeNode : public PostprocNode {
};
bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
+ // The rewriting jobs, 3-tuple (block_name, func_name, job_func)
+ std::vector<std::tuple<String, String, std::function<void(tir::BlockRV)>>>
jobs;
for (const auto& kv : sch->mod()->functions) {
GlobalVar g_var = kv.first;
BaseFunc base_func = kv.second;
if (const tir::PrimFuncNode* prim_func =
base_func.as<tir::PrimFuncNode>()) {
- ApplyTensorization(sch, g_var->name_hint, prim_func,
vectorize_init_loop);
+ CollectTensorizationJobs(sch, g_var->name_hint, prim_func,
vectorize_init_loop, &jobs);
}
}
+ for (const auto& job : jobs) {
+ const String& block_name = std::get<0>(job);
+ const String& func_name = std::get<1>(job);
+ const auto& job_func = std::get<2>(job);
+ BlockRV block = sch->GetBlock(block_name, func_name);
+ sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
+ job_func(block);
+ }
return true;
}