manupa-arm commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r639553994
##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const
std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
- Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool
legacy_te_pass) {
Review comment:
Clarification: what does legacy_te_pass mean here?
##########
File path: include/tvm/driver/driver_api.h
##########
@@ -48,10 +49,39 @@ namespace tvm {
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
+ * \param simple_mode Skips the LoopPartition pass if true. Defaults to false.
* \return The result module.
*/
TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const
std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>&
binds);
+ const std::unordered_map<te::Tensor, tir::Buffer>&
binds,
+ bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param args The arguments to the function.
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \param simple_mode Skips the LoopPartition pass if true. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule lower(IRModule mod, const Array<te::Tensor>& args, const
std::string& name,
+ const std::unordered_map<te::Tensor, tir::Buffer>&
binds,
+ bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
Review comment:
Typo: should this be a PrimFunc?
##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const
std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
- Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool
legacy_te_pass) {
auto pass_ctx = transform::PassContext::Current();
- sch = sch.normalize();
-
- // Before TIR transformation.
- auto bounds = te::InferBound(sch);
- auto stmt = te::ScheduleOps(sch, bounds, false);
- bool compact = te::VerifyCompactBuffer(stmt);
-
- Map<te::Tensor, tir::Buffer> out_binds;
- GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
- // build the function
- tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list,
std::move(stmt), out_binds);
- f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
- bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize",
Bool(false)).value();
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers",
Bool(false)).value();
- if (noalias) {
- f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+ // Get any user-added passes
+ auto add_lower_pass =
+ pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass",
Array<Array<ObjectRef>>())
+ .value();
+
+ auto user_lower_phase0 = Array<tvm::transform::Pass>();
Review comment:
Clarification: What is the importance of having phases, and what do they
represent?
##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const
std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
- Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool
legacy_te_pass) {
auto pass_ctx = transform::PassContext::Current();
- sch = sch.normalize();
-
- // Before TIR transformation.
- auto bounds = te::InferBound(sch);
- auto stmt = te::ScheduleOps(sch, bounds, false);
- bool compact = te::VerifyCompactBuffer(stmt);
-
- Map<te::Tensor, tir::Buffer> out_binds;
- GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
- // build the function
- tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list,
std::move(stmt), out_binds);
- f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
- bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize",
Bool(false)).value();
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers",
Bool(false)).value();
- if (noalias) {
- f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+ // Get any user-added passes
+ auto add_lower_pass =
+ pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass",
Array<Array<ObjectRef>>())
+ .value();
+
+ auto user_lower_phase0 = Array<tvm::transform::Pass>();
+ auto user_lower_phase1 = Array<tvm::transform::Pass>();
+ auto user_lower_phase2 = Array<tvm::transform::Pass>();
+ auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+ // phase pasees is of the form
+ // [[phase_number, pass], [phase_number, pass]... ]
+ for (auto phase_pass : add_lower_pass) {
+ auto phase_num = phase_pass[0].as<IntImmNode>();
+ ICHECK(phase_num)
+ << "Expected the first entry in the inner Array of tir.add_lower_pass
to be an integer";
+ int phase_num_val = phase_num->value;
+
+ CHECK_GT(phase_num_val, 0);
+
+ // TODO(electriclilies): is there a cleaner way to do this?
+ auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+ auto pass = GetRef<tvm::transform::Pass>(pass_node);
+ // Copy the pass into the correct phase
+ if (phase_num_val == 0) {
+ user_lower_phase0.push_back(pass);
+ } else if (phase_num_val == 1) {
+ user_lower_phase1.push_back(pass);
+ } else if (phase_num_val == 2) {
+ user_lower_phase2.push_back(pass);
+ } else if (phase_num_val >= 3) {
+ user_lower_phase3.push_back(pass);
+ }
}
- auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
- auto pass_list = Array<tvm::transform::Pass>();
+ // Construct the pass list, inserting the user provided passes at the end of
the phase
+ // TODO(electriclilies): I'm not sure if they should go at the beginning or
the end of the phase.
+ // The code is inconsistent with what passes are in which phase as well. For
now I have coped the
+ // python behavior exactly.
+
+ // PHASE 0
+ auto pass_list = user_lower_phase0;
- // Phase 0
- pass_list.push_back(tir::transform::InjectPrefetch());
- pass_list.push_back(tir::transform::StorageFlatten(64,
instrument_bound_checkers));
- // Phase 1
+ // PHASE 1
+ if (legacy_te_pass) {
Review comment:
@jroesch @CircleSpin
A general design question: what do you think of relocating the pass
pipeline closer to the target registry? That way we could have a function
(which could also be parameterized based on target args) that describes the pass
pipeline, rather than mandating the passes here.
Here, we could just query the pass pipeline based on the target and run
it here. I think this could remove the need to support user-defined custom passes at
this level.
Thoughts?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]