tqchen commented on a change in pull request #8044:
URL: https://github.com/apache/tvm/pull/8044#discussion_r632955612
##########
File path: python/tvm/driver/build_module.py
##########
@@ -160,16 +173,38 @@ def lower(sch, args, name="main", binds=None,
simple_mode=False):
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
# Phase 0
- if isinstance(sch, schedule.Schedule):
- mod = form_irmodule(sch, args, name, binds)
+ pass_list = lower_phase0
+ is_legacy_te_schedule: bool = False
+
+ if isinstance(inputs, schedule.Schedule):
+ if args is None:
+ raise ValueError("args must be given for lowering from TE
schedule")
+ mod = form_irmodule(inputs, args, name, binds)
+ is_legacy_te_schedule = True
+ elif isinstance(inputs, PrimFunc):
+ func = inputs.with_attr("global_symbol", name)
+ if pass_ctx.config.get("tir.noalias", True):
+ func = func.with_attr("tir.noalias", True)
+ mod = tvm.IRModule({name: func})
+ elif isinstance(inputs, IRModule):
+ mod = inputs
else:
- mod = sch
+ raise TypeError(
+ f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got
{type(inputs)}"
+ )
- pass_list = lower_phase0
# Phase 1
+ if is_legacy_te_schedule:
+ pass_list += [
+ tvm.tir.transform.InjectPrefetch(),
+ tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
+ ]
pass_list += [
- tvm.tir.transform.InjectPrefetch(),
- tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
+ tvm.tir.transform.LowerInitBlock(),
+ tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
+ tvm.tir.transform.ConvertBlocksToOpaque(),
+ tvm.tir.transform.CompactBufferAllocation(),
+ tvm.tir.transform.FlattenBuffer(),
Review comment:
We do not. I agree it is helpful to put the new TIR-schedule-specific passes
in an else block for now.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]