electriclilies commented on a change in pull request #8886:
URL: https://github.com/apache/tvm/pull/8886#discussion_r701326664
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -85,33 +85,46 @@ class TECompilerImpl : public TECompilerNode {
return LowerShapeFuncInternal(key)->cached_func;
}
- Map<Target, IRModule> GetLoweredFunctions() {
- std::unordered_map<Target, IRModule, backend::TargetStrHash,
backend::TargetStrEqual>
- lowered_functions;
+ IRModule GetLoweredFunctions() {
+ IRModule mod;
+ // Extract lowered functions from the cache
for (const auto& it : cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
- auto target = source_func->target;
- if (!lowered_functions.count(target)) {
- lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
- }
+ IRModule lowered_mod = lowered_func->cached_func->funcs;
- lowered_functions[target]->Update(lowered_func->cached_func->funcs);
- }
+ // Annotate functions with their target and put them in the return module
+ for (auto kv : lowered_mod->functions) {
+ const GlobalVar& var = kv.first;
+ const BaseFunc& func = kv.second;
+ // Only add functions that are not external functions
+ if (!func->GetAttr<String>(attr::kCompiler).defined()) {
+ ICHECK(func->IsInstance<tir::PrimFuncNode>())
+ << "Expected all functions that are not external to be
PrimFuncs, but found "
+ << func->GetTypeKey();
+ const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
+ mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget,
source_func->target));
+ }
+ }
+ }
+ // Extract lowered frunctions from the shape cache
Review comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]