This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a1d43c1  [Autoscheduler][VM] Autoscheduler layout rewrite pass to VM (#7516)
a1d43c1 is described below

commit a1d43c15ac6382831370c6de141bf80888761e70
Author: Thierry Moreau <[email protected]>
AuthorDate: Mon Mar 1 17:31:57 2021 -0800

    [Autoscheduler][VM] Autoscheduler layout rewrite pass to VM (#7516)
    
    * fix type inference for conv2d
    
    * fix
    
    * adding the autoscheduler layout rewrite pass to VM compiler passes
    
    * revert edits applied in other PR
    
    * minor fix
    
    * fix
    
    * formatting fix
    
    * lint
---
 src/relay/backend/vm/compiler.cc | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 7697b59..0718191 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1066,6 +1066,23 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets,
   }
 
   pass_seqs.push_back(transform::FuseOps());
+  // Do layout rewrite for auto-scheduler.
+  transform::PassContext pass_ctx = PassContext::Current();
+  if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) {
+    const auto& target = (*targets.begin()).second;
+    Pass major_pass = transform::AutoSchedulerLayoutRewrite();
+    bool enable_layout_rewrite_targets =
+        target->kind->device_type == kDLCPU || target->GetAttr<String>("device", "") == "mali";
+    if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
+      With<Target> tctx(target);
+      pass_seqs.push_back(major_pass);
+      // Defuse ops to fold constants, then fuse them again
+      pass_seqs.push_back(transform::DefuseOps());
+      pass_seqs.push_back(transform::FoldConstant());
+      pass_seqs.push_back(transform::FuseOps());
+    }
+  }
+
   pass_seqs.push_back(transform::ToANormalForm());
   pass_seqs.push_back(transform::InferType());
   pass_seqs.push_back(transform::LambdaLift());
@@ -1082,7 +1099,6 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets,
   pass_seqs.push_back(transform::InferType());
 
   transform::Sequential seq(pass_seqs);
-  transform::PassContext pass_ctx = PassContext::Current();
   tvm::With<relay::transform::PassContext> ctx(pass_ctx);
   if (targets.size() == 1) {
     const auto& it = targets.begin();

Reply via email to