This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a1d43c1 [Autoscheduler][VM] Autoscheduler layout rewrite pass to VM (#7516)
a1d43c1 is described below
commit a1d43c15ac6382831370c6de141bf80888761e70
Author: Thierry Moreau <[email protected]>
AuthorDate: Mon Mar 1 17:31:57 2021 -0800
[Autoscheduler][VM] Autoscheduler layout rewrite pass to VM (#7516)
* fix type inference for conv2d
* fix
* adding the autoscheduler layout rewrite pass to VM compiler passes
* revert edits applied in other PR
* minor fix
* fix
* formatting fix
* lint
---
src/relay/backend/vm/compiler.cc | 18 +++++++++++++++++-
1 file changed, 17 insertions(+), 1 deletion(-)
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 7697b59..0718191 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1066,6 +1066,23 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets,
}
pass_seqs.push_back(transform::FuseOps());
+ // Do layout rewrite for auto-scheduler.
+ transform::PassContext pass_ctx = PassContext::Current();
+ if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) {
+ const auto& target = (*targets.begin()).second;
+ Pass major_pass = transform::AutoSchedulerLayoutRewrite();
+ bool enable_layout_rewrite_targets =
+ target->kind->device_type == kDLCPU ||
target->GetAttr<String>("device", "") == "mali";
+ if (enable_layout_rewrite_targets &&
pass_ctx.PassEnabled(major_pass->Info())) {
+ With<Target> tctx(target);
+ pass_seqs.push_back(major_pass);
+ // Defuse ops to fold constants, then fuse them again
+ pass_seqs.push_back(transform::DefuseOps());
+ pass_seqs.push_back(transform::FoldConstant());
+ pass_seqs.push_back(transform::FuseOps());
+ }
+ }
+
pass_seqs.push_back(transform::ToANormalForm());
pass_seqs.push_back(transform::InferType());
pass_seqs.push_back(transform::LambdaLift());
@@ -1082,7 +1099,6 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets,
pass_seqs.push_back(transform::InferType());
transform::Sequential seq(pass_seqs);
- transform::PassContext pass_ctx = PassContext::Current();
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
if (targets.size() == 1) {
const auto& it = targets.begin();