csullivan commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r646912412



##########
File path: python/tvm/autotvm/feature.py
##########
@@ -39,7 +39,7 @@ def ana_lower(sch, args, binds=None, simple_mode=True):
     """Do lower while keeping all axes in IR
     i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or 
inject virtual threads
     """
-    binds, _ = build_module.get_binds(args, binds)
+    binds, _ = build_module.get_binds(args, False, binds)

Review comment:
       ```suggestion
       binds, _ = build_module.get_binds(args, compact=False, binds=binds)
   ```
   Prefer keyword arguments at the call site for readability.

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& 
name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
-TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const 
std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& 
binds);
 
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, 
tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a schedule, args and binds
+ * \param sch The schedule to lower.
+ * \param args The arguments to the function (Array of Union of Tensor, Buffer 
and Vars)

Review comment:
       ```suggestion
    * \param args The arguments to the function (Array of Tensor, Buffer and 
Vars)
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const 
std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool 
for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", 
Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", 
Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", 
Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass 
to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of 
the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), 
user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), 
user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), 
user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> 
pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, 
const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& 
binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const 
String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, 
kv.second));
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);
+      return mod;
+    });
+
+IRModule LowerModule(IRModule mod, bool simple_mode) {
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, 
bool simple_mode) {
+  return LowerModule(mod, simple_mode);
+});
+
+IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool 
simple_mode) {
+  auto pass_ctx = transform::PassContext::Current();
+  auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+
+  // Get the pass list
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_primfunc")
+    .set_body_typed([](te::PrimFunc func, const String& name, bool 
simple_mode) {
+      return LowerPrimFunc(func, name, simple_mode);
+    });
+
+IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args, const 
std::string& name,
+                       const std::unordered_map<te::Tensor, tir::Buffer>& 
binds, bool simple_mode) {
+  Array<ObjectRef> ref_args;
+  for (auto x : args) {
+    ref_args.push_back(x);
+  }
+  return LowerSchedule(sch, ref_args, name, binds);
+}
+
+IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const 
std::string& name,
+                       const std::unordered_map<te::Tensor, tir::Buffer>& 
binds, bool simple_mode) {
+  IRModule mod = ScheduleToModule(sch, args, name, binds);
+  // Get the legacy TE pass list
+  auto pass_list = CreatePassList(simple_mode, true);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_schedule")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const 
String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds, bool 
simple_mode) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {

Review comment:
       ```suggestion
         if (binds.get() != nullptr) {
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const 
std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool 
legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", 
Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", 
Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", 
Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();

Review comment:
       It would be nice to take the refactoring opportunity to add some 
documentation on this, though I completely understand that you have mirrored 
the previous lower impl.

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds

Review comment:
       ```suggestion
    * \brief Build an IRModule given a module
   ```
   Updated the comment to reflect the API. Please check the docs for the other 
APIs that don't take args/binds as well.

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const 
std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool 
for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", 
Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", 
Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", 
Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass 
to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of 
the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), 
user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), 
user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse

Review comment:
       ```suggestion
   ```
   Either add more detail to this comment or remove it; as written, it isn't 
needed.

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const 
std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool 
for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", 
Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", 
Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", 
Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass 
to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of 
the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), 
user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), 
user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), 
user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> 
pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, 
const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& 
binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const 
String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, 
kv.second));

Review comment:
       ```suggestion
             c_binds.insert({kv.first, kv.second});
   ```

##########
File path: python/tvm/driver/build_module.py
##########
@@ -37,92 +37,54 @@
 from tvm.tir.buffer import Buffer
 from tvm.tir.expr import Var
 
+from . import _ffi_api as ffi
+
 
 def get_binds(args, compact=False, binds=None):
     """Internal function to get binds and arg_list given arguments.
-
     Parameters
     ----------
     args : list of Buffer or Tensor or Var
         The argument lists to the function.
-
     compact : bool
         If the statement has already bound to a compact buffer.
-
     binds : dict of :any:`Tensor` to :any:`Buffer`, optional
         Dictionary that maps the Tensor to Buffer which specified the data 
layout
         requirement of the function. By default, a new compact buffer is 
created
         for each tensor in the argument.
-
     Returns
     -------
     binds: dict
         The bind specification
-
     arg_list: list
         The list of symbolic buffers of arguments.
     """
-    binds = {} if binds is None else binds.copy()
-    arg_list = []
-    for x in args:
-        if isinstance(x, tensor.Tensor):
-            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
-            buffer_type = "auto_broadcast" if any_dim and not compact else ""
-            if x not in binds:
-                buf = tvm.tir.decl_buffer(
-                    x.shape, dtype=x.dtype, name=x.name, 
buffer_type=buffer_type
-                )
-                binds[x] = buf
-                arg_list.append(buf)
-            else:
-                arg_list.append(binds[x])
-        elif isinstance(x, schedule.Buffer):
-            arg_list.append(x)
-        elif isinstance(x, tvm.tir.Var):
-            arg_list.append(x)
-        else:
-            raise ValueError("args must be Tensor, Buffer or Var")
-    return binds, arg_list
-
-
-def form_irmodule(sch, args, name, binds):
-    """According to the given schedule, form a function.
+    out_arr = ffi.get_binds(args, compact, binds)
+    return out_arr[0], out_arr[1]
+

Review comment:
       We should probably update the docstring to match the return variable 
names. Actually, the reverse would be preferable (`binds, arg_list = 
ffi.get_binds(...)`), but I understand you are limited here by the object 
system.

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,

Review comment:
       Is there a reason to pass the outputs as output parameters other than 
needing multiple return values?  I couldn't see one from the use of GetBinds in 
your PR, but maybe I missed it.
   
   If not, we could return a std::pair/std::tuple, unpack it with std::tie, and 
later upgrade to structured bindings if we move to C++17.
   ```
   // now
   Array<ObjectRef> out_arg_list;
   Map<te::Tensor, tir::Buffer> out_binds;
   std::tie(out_binds, out_arg_list) = GetBinds(args, compact, binds);
   
   // c++17
   auto [out_binds, out_arg_list] = GetBinds(args, compact, binds);
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* 
out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, 
x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      ICHECK(false)
+          << "Expected type of the elements of args to be te::Tensor, 
te::Buffer or tir::Var";
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("driver.get_binds")
+    .set_body_typed([](const Array<ObjectRef>& args, bool compact,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {

Review comment:
       ```suggestion
         if (binds.get() != nullptr) {
   ```
   

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* 
out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, 
x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      ICHECK(false)
+          << "Expected type of the elements of args to be te::Tensor, 
te::Buffer or tir::Var";
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("driver.get_binds")
+    .set_body_typed([](const Array<ObjectRef>& args, bool compact,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, 
kv.second));
+        }
+      }
+      Map<te::Tensor, tir::Buffer> out_binds;
+      Array<ObjectRef> out_arg_list;
+      GetBinds(args, compact, c_binds, &out_binds, &out_arg_list);
+
+      // TVM object system doesn't have a pair object, so we'll put both ret 
values in an array
+      // and return that.
+      Array<ObjectRef> out_arr = {out_binds, out_arg_list};

Review comment:
       I'm fine with the array impl., but do you know if there is interest in 
adding a pair type? I would have guessed there was one already, given that we 
have Map. `Map<String, ObjectRef>` could work, albeit at the cost of a string lookup.
   ```
   out = driver.get_binds(...)
   return out["binds"], out["arg_list"]
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* 
out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, 
x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      ICHECK(false)
+          << "Expected type of the elements of args to be te::Tensor, 
te::Buffer or tir::Var";
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("driver.get_binds")
+    .set_body_typed([](const Array<ObjectRef>& args, bool compact,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, 
kv.second));

Review comment:
       ```suggestion
             c_binds.insert({kv.first, kv.second});
   ```

##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -770,7 +770,8 @@ class CompileEngineImpl : public CompileEngineNode {
       With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
       std::unordered_map<te::Tensor, tir::Buffer> binds;
-      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, 
cache_node->func_name, binds);
+      cache_node->funcs =
+          tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, 
binds);

Review comment:
       We shouldn't need to route through the `relay.backend.lower` packed 
function any longer, right? I'd suggest either removing it and the above 
conditional (line 766) or also reproducing its functionality in C++ if we do 
need it as part of this PR. This way we avoid an unnecessary round trip through 
Python. 

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const 
std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool 
for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", 
Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", 
Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", 
Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass 
to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of 
the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), 
user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), 
user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), 
user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> 
pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, 
const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& 
binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const 
String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, 
kv.second));
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);
+      return mod;
+    });
+
+IRModule LowerModule(IRModule mod, bool simple_mode) {
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, 
bool simple_mode) {
+  return LowerModule(mod, simple_mode);
+});
+
+IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool 
simple_mode) {
+  auto pass_ctx = transform::PassContext::Current();
+  auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+
+  // Get the pass list
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_primfunc")
+    .set_body_typed([](te::PrimFunc func, const String& name, bool 
simple_mode) {
+      return LowerPrimFunc(func, name, simple_mode);
+    });
+
+IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args, const 
std::string& name,
+                       const std::unordered_map<te::Tensor, tir::Buffer>& 
binds, bool simple_mode) {
+  Array<ObjectRef> ref_args;
+  for (auto x : args) {
+    ref_args.push_back(x);
+  }
+  return LowerSchedule(sch, ref_args, name, binds);
+}
+
+IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const 
std::string& name,
+                       const std::unordered_map<te::Tensor, tir::Buffer>& 
binds, bool simple_mode) {
+  IRModule mod = ScheduleToModule(sch, args, name, binds);
+  // Get the legacy TE pass list
+  auto pass_list = CreatePassList(simple_mode, true);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_schedule")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const 
String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds, bool 
simple_mode) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, 
kv.second));

Review comment:
       ```suggestion
             c_binds.insert({kv.first, kv.second});
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const 
std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool 
for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", 
Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", 
Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", 
Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass 
to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of 
the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), 
user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), 
user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), 
user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> 
pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, 
const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& 
binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const 
String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {

Review comment:
       ```suggestion
         if (binds.get() != nullptr) {
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to