electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r641219409
##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +154,219 @@ transform::Pass Filter(FCond fcond) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const
std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
- Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool enable_loop_partition, bool
for_te_schedule) {
Review comment:
Is there a better way to handle the creation of the pass list? Should I
split this function into two: one for TE schedules and one for IRModules
and PrimFuncs?
##########
File path: src/driver/driver_api.cc
##########
@@ -93,6 +93,7 @@ tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape,
DataType dtype, std
offset_factor, buffer_type);
}
+// comment to try to remove this
void GetBinds(const Array<te::Tensor>& args, bool compact,
Review comment:
Is having two copies of GetBinds OK?
##########
File path: python/tvm/driver/build_module.py
##########
@@ -160,90 +99,15 @@ def lower(
m : IRModule
The result IRModule
"""
- # config setup
- pass_ctx = PassContext.current()
- instrument_bound_checkers =
bool(pass_ctx.config.get("tir.instrument_bound_checkers", False))
- disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize",
False))
- add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", [])
-
- lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
- lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
- lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
- lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
-
- # Phase 0
- pass_list = lower_phase0
- is_legacy_te_schedule: bool = False
-
- if isinstance(inputs, schedule.Schedule):
- if args is None:
- raise ValueError("args must be given for lowering from TE
schedule")
- mod = form_irmodule(inputs, args, name, binds)
- is_legacy_te_schedule = True
- elif isinstance(inputs, PrimFunc):
- func = inputs.with_attr("global_symbol", name)
- if pass_ctx.config.get("tir.noalias", True):
- func = func.with_attr("tir.noalias", True)
- mod = tvm.IRModule({name: func})
- elif isinstance(inputs, IRModule):
- mod = inputs
- else:
- raise TypeError(
- f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got
{type(inputs)}"
- )
-
- # Phase 1
- if is_legacy_te_schedule:
- pass_list += [
- tvm.tir.transform.InjectPrefetch(),
- tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
- ]
- else:
- pass_list += [
- tvm.tir.transform.LowerInitBlock(),
- tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
- tvm.tir.transform.ConvertBlocksToOpaque(),
- tvm.tir.transform.CompactBufferAllocation(),
- tvm.tir.transform.FlattenBuffer(),
- ]
- pass_list += [
- tvm.tir.transform.BF16Legalize(),
- tvm.tir.transform.NarrowDataType(32),
- tvm.tir.transform.Simplify(),
- ]
-
- pass_list += lower_phase1
-
- # Phase 2
- if not simple_mode:
- pass_list += [(tvm.tir.transform.LoopPartition())]
-
- pass_list += [
- tvm.tir.transform.VectorizeLoop(not disable_vectorize),
- tvm.tir.transform.InjectVirtualThread(),
- tvm.tir.transform.InjectDoubleBuffer(),
- tvm.tir.transform.StorageRewrite(),
- tvm.tir.transform.UnrollLoop(),
- ]
- pass_list += lower_phase2
-
- # Phase 3
- pass_list += [
- tvm.tir.transform.Simplify(),
- tvm.tir.transform.RemoveNoOp(),
- ]
-
- pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
- pass_list += [tvm.tir.transform.HoistIfThenElse()]
- pass_list += lower_phase3
-
- # Instrument BoundCheckers
- if instrument_bound_checkers:
- pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]
-
- optimize = tvm.transform.Sequential(pass_list)
- mod = optimize(mod)
- return mod
+ if isinstance(input, IRModule):
+ return ffi.lower_module(input)
+ if isinstance(input, PrimFunc):
+ return ffi.lower_primfunc(input)
Review comment:
Should I split the Python `lower` function into three functions as well, or
leave it as a single API for now?
##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +154,219 @@ transform::Pass Filter(FCond fcond) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const
std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
- Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool enable_loop_partition, bool
for_te_schedule) {
auto pass_ctx = transform::PassContext::Current();
- sch = sch.normalize();
-
- // Before TIR transformation.
- auto bounds = te::InferBound(sch);
- auto stmt = te::ScheduleOps(sch, bounds, false);
- bool compact = te::VerifyCompactBuffer(stmt);
-
- Map<te::Tensor, tir::Buffer> out_binds;
- GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
- // build the function
- tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list,
std::move(stmt), out_binds);
- f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
- bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize",
Bool(false)).value();
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers",
Bool(false)).value();
- if (noalias) {
- f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+ // Get any user-added passes
+ auto add_lower_pass =
+ pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass",
Array<Array<ObjectRef>>())
+ .value();
+
+ auto user_lower_phase0 = Array<tvm::transform::Pass>();
+ auto user_lower_phase1 = Array<tvm::transform::Pass>();
+ auto user_lower_phase2 = Array<tvm::transform::Pass>();
+ auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+ // phase pasees is of the form
+ // [[phase_number, pass], [phase_number, pass]... ]
+ for (auto phase_pass : add_lower_pass) {
+ auto phase_num = phase_pass[0].as<IntImmNode>();
+ ICHECK(phase_num)
+ << "Expected the first entry in the inner Array of tir.add_lower_pass
to be an integer";
+ int phase_num_val = phase_num->value;
+
+ CHECK_GT(phase_num_val, 0);
+
+ auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+ auto pass = GetRef<tvm::transform::Pass>(pass_node);
+ // Copy the pass into the correct phase
+ if (phase_num_val == 0) {
+ user_lower_phase0.push_back(pass);
+ } else if (phase_num_val == 1) {
+ user_lower_phase1.push_back(pass);
+ } else if (phase_num_val == 2) {
+ user_lower_phase2.push_back(pass);
+ } else if (phase_num_val >= 3) {
+ user_lower_phase3.push_back(pass);
+ }
}
- auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
- auto pass_list = Array<tvm::transform::Pass>();
+ // Construct the pass list, inserting the user provided passes at the end of
the phase
+ // TODO(electriclilies): I'm not sure if they should go at the beginning or
the end of the phase.
+ // The code is inconsistent with what passes are in which phase as well. For
now I have coped the
+ // python behavior exactly.
- // Phase 0
- pass_list.push_back(tir::transform::InjectPrefetch());
- pass_list.push_back(tir::transform::StorageFlatten(64,
instrument_bound_checkers));
- // Phase 1
+ // PHASE 0
+ auto pass_list = user_lower_phase0;
+
+ // PHASE 1
+ if (for_te_schedule) {
+ pass_list.push_back(tir::transform::InjectPrefetch());
+ pass_list.push_back(tir::transform::StorageFlatten(64,
instrument_bound_checkers));
+ } else {
+ pass_list.push_back(tir::transform::LowerInitBlock());
+
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+ pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+ pass_list.push_back(tir::transform::CompactBufferAllocation());
+ pass_list.push_back(tir::transform::FlattenBuffer());
+ }
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
- pass_list.push_back(tir::transform::LoopPartition());
+
+ // Add user-defined phase-1 passes
+ pass_list.insert(pass_list.end(), user_lower_phase1.begin(),
user_lower_phase1.end());
+
+ // PHASE 2
+ if (enable_loop_partition) {
+ pass_list.push_back(tir::transform::LoopPartition());
+ }
+
pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
pass_list.push_back(tir::transform::InjectVirtualThread());
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::StorageRewrite());
pass_list.push_back(tir::transform::UnrollLoop());
- // Phase 2
+
+ // Add user-defined phase-2 passes
+ pass_list.insert(pass_list.end(), user_lower_phase2.begin(),
user_lower_phase2.end());
+
+ // PHASE 3
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::RemoveNoOp());
pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+ // HoistIfThenElse
+ pass_list.push_back(tir::transform::HoistIfThenElse());
+
+ // Add user-defined phase-3 passes
+ pass_list.insert(pass_list.end(), user_lower_phase3.begin(),
user_lower_phase3.end());
+
if (instrument_bound_checkers) {
pass_list.push_back(tir::transform::InstrumentBoundCheckers());
}
- // run
+ return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass>
pass_list) {
auto optimize = transform::Sequential(pass_list);
mod = optimize(std::move(mod));
return mod;
}
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args,
const std::string& name,
+ const std::unordered_map<te::Tensor, tir::Buffer>&
binds) {
+ // Convert te schedule to IRModule
+ Array<ObjectRef> out_arg_list;
+ auto pass_ctx = transform::PassContext::Current();
+
+ sch = sch.normalize();
+
+ // Before TIR transformation.
+ auto bounds = te::InferBound(sch);
+ auto stmt = te::ScheduleOps(sch, bounds, false);
+ bool compact = te::VerifyCompactBuffer(stmt);
+
+ Map<te::Tensor, tir::Buffer> out_binds;
+ GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+ // Build the function
+ // At this point binds is only te::Tensors
+ stmt = te::SchedulePostProcRewriteForTensorCore(
+ stmt, sch,
+ binds); // TODO(electriclilies): Should this be in here? Was in python
but not C++ version.
Review comment:
What is this doing, and should it be here? It was in the Python version but
not in the C++ version.
##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +154,219 @@ transform::Pass Filter(FCond fcond) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const
std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
- Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool enable_loop_partition, bool
for_te_schedule) {
auto pass_ctx = transform::PassContext::Current();
- sch = sch.normalize();
-
- // Before TIR transformation.
- auto bounds = te::InferBound(sch);
- auto stmt = te::ScheduleOps(sch, bounds, false);
- bool compact = te::VerifyCompactBuffer(stmt);
-
- Map<te::Tensor, tir::Buffer> out_binds;
- GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
- // build the function
- tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list,
std::move(stmt), out_binds);
- f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
- bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize",
Bool(false)).value();
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers",
Bool(false)).value();
- if (noalias) {
- f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+ // Get any user-added passes
+ auto add_lower_pass =
+ pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass",
Array<Array<ObjectRef>>())
+ .value();
+
+ auto user_lower_phase0 = Array<tvm::transform::Pass>();
+ auto user_lower_phase1 = Array<tvm::transform::Pass>();
+ auto user_lower_phase2 = Array<tvm::transform::Pass>();
+ auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+ // phase pasees is of the form
+ // [[phase_number, pass], [phase_number, pass]... ]
+ for (auto phase_pass : add_lower_pass) {
+ auto phase_num = phase_pass[0].as<IntImmNode>();
+ ICHECK(phase_num)
+ << "Expected the first entry in the inner Array of tir.add_lower_pass
to be an integer";
+ int phase_num_val = phase_num->value;
+
+ CHECK_GT(phase_num_val, 0);
+
+ auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+ auto pass = GetRef<tvm::transform::Pass>(pass_node);
+ // Copy the pass into the correct phase
+ if (phase_num_val == 0) {
+ user_lower_phase0.push_back(pass);
+ } else if (phase_num_val == 1) {
+ user_lower_phase1.push_back(pass);
+ } else if (phase_num_val == 2) {
+ user_lower_phase2.push_back(pass);
+ } else if (phase_num_val >= 3) {
+ user_lower_phase3.push_back(pass);
+ }
}
- auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
- auto pass_list = Array<tvm::transform::Pass>();
+ // Construct the pass list, inserting the user provided passes at the end of
the phase
+ // TODO(electriclilies): I'm not sure if they should go at the beginning or
the end of the phase.
Review comment:
What is the significance of the phases, and where are the boundaries between
them? The phase boundaries were inconsistent between the C++ and Python versions. For now I
have copied the Python behavior exactly, but I'm not sure that is correct.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]