mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759587798
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
return mod;
}
+ void AddExterns(IRModule module) {
+ // Everything tagged with "Compiler" has been compiled, so remove those
definitions.
+ std::vector<GlobalVar> to_be_deleted;
+ for (const auto& kv : module->functions) {
+ if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+ to_be_deleted.push_back(kv.first);
+ }
+ }
+ for (const auto& global_var : to_be_deleted) {
+ module->Remove(global_var);
+ }
+ // HOWEVER we still need a Relay definition to go with those now external
functions, so
+ // retrieve them from the cache and mark them with "ExternalSymbol".
+ for (const auto& kv1 : cache_) {
+ auto src_func = kv1.first->source_func;
+ ICHECK(src_func.defined());
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
+ if (const auto* function_node = kv2.second.as<FunctionNode>()) {
+ // Abandon the existing function annotations.
+ Function function(function_node->params, function_node->body,
function_node->ret_type,
+ function_node->type_params, /*attrs=*/{},
function_node->span);
+ // Mark function as 'extern' using the "ExternalSymbol" attribute.
+ function = WithAttr(std::move(function), attr::kExternalSymbol,
kv2.first->name_hint);
+ module->Add(kv2.first, function);
+ }
+ }
+ }
+ }
+ }
+
Array<tvm::runtime::Module> LowerExternalFunctions() {
Array<tvm::runtime::Module> ret;
- std::unordered_map<std::string, std::string> cached_symbol;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
ICHECK(src_func.defined());
- if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
- auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
- std::string code_gen_name = code_gen.value();
+ Optional<String> opt_compiler =
src_func->GetAttr<String>(attr::kCompiler);
+ if (opt_compiler.defined()) {
+ Optional<String> opt_symbol_name =
src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:"
<< std::endl
+ << PrettyPrint(src_func);
+ VLOG(1) << "using external codegen '" << opt_compiler.value() << "'
for name '"
+ << opt_symbol_name.value() << "' and function:" << std::endl
+ << PrettyPrint(src_func);
cached_ext_funcs.push_back(it.first);
- auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
- << AsText(src_func, false);
-
- std::string sn = symbol_name.value();
- if (cached_symbol.count(sn)) {
- cached_symbol[sn] = code_gen_name;
- } else {
- ICHECK_NE(sn, code_gen_name)
- << "Found duplicated symbol: " << sn << " for: " <<
code_gen_name;
- }
-
- std::string ext_name = "relay.ext." + code_gen_name;
+ std::string ext_name = "relay.ext." + opt_compiler.value();
auto pf = tvm::runtime::Registry::Get(ext_name);
ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
// No need to keep compiler attribute at this point, functions have
been
// extracted for specific codegen.
src_func = WithAttr(std::move(src_func), attr::kCompiler,
NullValue<ObjectRef>());
+ VLOG_CONTEXT << ext_name;
runtime::Module ext_mod = (*pf)(src_func);
-
- ICHECK(ext_mod.defined()) << "No external runtime is generated.";
- ret.push_back(ext_mod);
+ if (ext_mod.defined()) {
+ if (ext_mod->GetFunction(opt_symbol_name.value(),
/*query_imports=*/true) == nullptr) {
+ // It's possible the codegen yielded C or C++ tracked separately
and thus the
+ // returned runtime module can be empty.
Review comment:
I asked our ARM friends about this since it was also bugging me. It
turns out the runtime::Module will be an EthosUModule, which does not support
GetFunction for the compiled functions but *does* support GetFunction for the
meta function 'get_func_names', which would yield a list of those names.
However, that's not a standard interface te_compiler can depend on. Better
would be a HasFunction method (or something similar) on runtime::Module, but
I'm leaving that for a rain(ier) day.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]