jroesch commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r638402304
##########
File path: src/driver/driver_api.cc
##########
@@ -185,6 +170,59 @@ IRModule lower(te::Schedule sch, const Array<te::Tensor>&
args, const std::strin
return mod;
}
+IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const
std::string& name,
+ const std::unordered_map<te::Tensor, tir::Buffer>& binds, bool
simple_mode) {
+ // Convert te schedule to IRModule
+ Array<ObjectRef> out_arg_list;
+ auto pass_ctx = transform::PassContext::Current();
+
+ sch = sch.normalize();
+
+ // Before TIR transformation.
+ auto bounds = te::InferBound(sch);
+ auto stmt = te::ScheduleOps(sch, bounds, false);
+ bool compact = te::VerifyCompactBuffer(stmt);
+
+ Map<te::Tensor, tir::Buffer> out_binds;
+ GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+ // build the function
+ tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list,
std::move(stmt), out_binds);
+ f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+ bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+ if (noalias) {
+ f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+ }
+ IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+ return lower(mod, args, name, binds, simple_mode);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower")
+ .set_body_typed([](ObjectRef obj, const Array<te::Tensor>& args, const
String& name,
+ const Map<te::Tensor, tir::Buffer>& binds, bool
simple_mode) {
+ std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+ // Check to make sure binds is not null before doing the conversion;
+ if (binds.get() != NULL) {
+ for (auto kv : binds) {
+ c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first,
kv.second));
+ }
+ }
+
+ if (const auto* p_mod = obj.as<IRModuleNode>()) {
+ IRModule mod = GetRef<IRModule>(p_mod);
+ return lower(mod, args, name, c_binds, simple_mode);
+ } else if (const auto* p_sch = obj.as<te::ScheduleNode>()) {
+ te::Schedule sch = GetRef<te::Schedule>(p_sch);
+ return lower(sch, args, name, c_binds, simple_mode);
+ } else {
+ ICHECK(false) << "driver.lower expects the first argument to be a
te::Schedule or "
+ << "IRModule";
+ throw;
Review comment:
Don't add a `throw` here; instead, return an empty `IRModule(...)`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org