manupa-arm commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r640833354



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase passes is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GT(phase_num_val, 0);
+
+    // TODO(electriclilies): is there a cleaner way to do this?
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.
+  // The code is inconsistent with what passes are in which phase as well. For now I have copied the
+  // python behavior exactly.
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (legacy_te_pass) {

Review comment:
       Yes, OK. Having looked at what's being removed, maybe this is out of 
scope for the PR.
   
   My original question was about the motivation for custom passes, 
specifically the four possible locations where they get inserted (known as 
phases). If the phases have some meaning, it might be worth adding a comment.
   
   My next point was about the custom-pass feature itself: I was thinking we 
could instead have a custom pass pipeline registered or provided alongside the 
target. That way, someone could create a new target and reuse or reorganize 
the passes that need to run.
   
   Again, I now feel that's out of scope for this PR, as it essentially 
mimics what the Python lower is doing. :) 
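
To make the phase question concrete, here is a minimal standalone sketch of the bucketing the diff performs on `tir.add_lower_pass`. It does not use TVM: `PassName` and `BucketByPhase` are hypothetical names, and passes are represented as plain strings. It assumes phase 0 is meant to be valid, since the diff has a `phase_num_val == 0` branch even though `CHECK_GT(phase_num_val, 0)` as written would reject it.

```cpp
#include <array>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for tvm::transform::Pass; only a name is kept.
using PassName = std::string;

// Buckets (phase_number, pass) pairs into four phases, following the branch
// structure in the diff: 0, 1 and 2 map directly, anything >= 3 lands in 3.
// Assumption: negative phase numbers are rejected and phase 0 is accepted.
std::array<std::vector<PassName>, 4> BucketByPhase(
    const std::vector<std::pair<int, PassName>>& add_lower_pass) {
  std::array<std::vector<PassName>, 4> phases;
  for (const auto& entry : add_lower_pass) {
    if (entry.first < 0) {
      throw std::invalid_argument("phase number must be non-negative");
    }
    const int slot = entry.first >= 3 ? 3 : entry.first;
    // Insertion order within a phase is preserved.
    phases[slot].push_back(entry.second);
  }
  return phases;
}
```

With this shape, a pass registered for phase 5 ends up in the same bucket as one registered for phase 3, which is part of why a comment on what each phase means would help.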
   
   
   

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {

Review comment:
       Maybe we could be more specific with the name? E.g. something like "bool 
from_te_schedule"?
   
   Also, I see simple_mode only toggles LoopPartition; could we use "bool 
enable_loop_partition"?





