Lunderberg commented on a change in pull request #9263:
URL: https://github.com/apache/tvm/pull/9263#discussion_r729841052
##########
File path: src/driver/driver_api.cc
##########
@@ -334,6 +324,42 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
return mod;
});
+tir::PrimFunc ScheduleToPrimFunc(te::Schedule sch, const Array<ObjectRef>& args,
+ const std::string& name,
+ const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+ bool debug_keep_trivial_loop =
+ pass_ctx->GetConfig<Bool>("tir.debug_keep_trivial_loop", Bool(false)).value();
+
Review comment:
This tells `ScheduleOps` to keep loops that have an extent of 1, whereas
the default behavior is to replace trivial loops with a `Let` statement
([impl](https://github.com/apache/tvm/blob/main/src/te/operation/compute_op.cc#L365)).
The only place it is used is in `lower_ethosu`, to maintain the
[previous behavior](https://github.com/apache/tvm/blob/main/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py#L28)
of keeping the trivial loops. That said, I haven't looked into why the
trivial loops are kept in that case.
Edit: It is also used in the autotvm feature extraction.
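To make the difference concrete, here is a toy sketch in plain Python. It does not use TVM: the `For`, `Let`, and `lower_loop` names are hypothetical stand-ins for illustration only, not the `ScheduleOps` implementation. It only shows the idea that a loop with an extent of 1 is either kept as a loop or replaced by a binding of the loop variable to its start value.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Let:
    """Binds `var` to a constant `value` for the duration of `body`."""
    var: str
    value: int
    body: object


@dataclass
class For:
    """A loop running `var` over [start, start + extent)."""
    var: str
    start: int
    extent: int
    body: object


def lower_loop(var: str, start: int, extent: int, body: object,
               keep_trivial_loop: bool = False) -> Union[For, Let]:
    """Toy stand-in for the choice described above (not TVM code)."""
    if extent == 1 and not keep_trivial_loop:
        # Default: a single-iteration loop collapses to a binding.
        return Let(var, start, body)
    # Otherwise the loop is emitted unchanged.
    return For(var, start, extent, body)


# The same extent-1 loop, lowered with and without the flag.
default = lower_loop("i", 0, 1, "A[i] = B[i]")
kept = lower_loop("i", 0, 1, "A[i] = B[i]", keep_trivial_loop=True)
print(type(default).__name__, type(kept).__name__)
```

In TVM itself, the corresponding switch is the `tir.debug_keep_trivial_loop` option read from the current `PassContext`, as in the diff above.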
##########
File path: python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
##########
@@ -64,22 +64,18 @@ def lower_ethosu(sch, args, const_dict, name="main"):
"no_unroll_loop_with_extent_one": True,
},
"tir.UnrollLoop": {"auto_max_depth": -1},
+ "tir.debug_keep_trivial_loop": True,
}
# Merge two configs
curr_cfg = {**curr_cfg, **tir_compiler_cfg}
sch = sch.normalize()
- bounds = tvm.te.schedule.InferBound(sch)
- stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True)
- compact = tvm.te.schedule.VerifyCompactBuffer(stmt)
- binds, arg_list = get_binds(args, compact, None)
- func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
-
- func = func.with_attr("global_symbol", name)
- func = func.with_attr("tir.noalias", True)
- mod = tvm.IRModule({name: func})
with tvm.transform.PassContext(config=curr_cfg):
+ func = schedule_to_primfunc(sch, args, name)
+ func = func.with_attr("tir.noalias", True)
+ mod = tvm.IRModule({name: func})
+
Review comment:
That was due to silliness and a lack of pattern recognition on my part: at
that point I was only looking for cases that could be replaced with
`schedule_to_primfunc`. I'm changing it to `schedule_to_module`. Thank you
for catching it!
##########
File path: python/tvm/autotvm/feature.py
##########
@@ -31,20 +31,18 @@
import tvm._ffi
from tvm.target import Target
-from tvm.te import schedule
from tvm.driver import build_module
def ana_lower(sch, args, binds=None, simple_mode=True):
"""Do lower while keeping all axes in IR
i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or
inject virtual threads
"""
- binds, _ = build_module.get_binds(args, compact=False, binds=binds)
sch = sch.normalize()
# Phase 0
- bounds = schedule.InferBound(sch)
- stmt = schedule.ScheduleOps(sch, bounds, True)
- func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
+ context = tvm.transform.PassContext(config={"tir.debug_keep_trivial_loop": True})
Review comment:
This deliberately keeps the loop iterators in place, even if they have
`extent=1`, rather than following the default behavior of replacing trivial
iterators with a `Let` statement. As a result, the itervars can be examined for
optimization parameters (e.g. in
[xgboost](https://github.com/apache/tvm/blob/main/python/tvm/autotvm/tuner/xgboost_cost_model.py#L361)).
Longer term, I'd prefer that it always generate the loops, with a separate
lowering pass to identify and simplify the trivial loops, but that's a later item.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.