jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r609055722
##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -731,50 +170,50 @@ class CompileEngineImpl : public CompileEngineNode {
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
- auto cache_node = make_object<CachedFuncNode>();
+ auto ir_module = IRModule();
const auto name_node =
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(name_node.defined()) << "External function has not been attached
a name yet.";
- cache_node->func_name = std::string(name_node.value());
- cache_node->target = Target("ext_dev");
- cache_node->funcs->Add(GlobalVar(cache_node->func_name),
key->source_func);
- value->cached_func = CachedFunc(cache_node);
+ auto func_name = std::string(name_node.value());
+ auto target = Target("ext_dev");
+ auto global_var = GlobalVar(func_name);
+ global_var->checked_type_ = key->source_func->checked_type();
+ ir_module->Add(global_var, key->source_func);
+ value->cached_func = CachedFunc(target, global_var, {}, {},
te::Schedule(), {}, ir_module);
return value;
}
+
// Enforce use the target.
With<Target> target_scope(key->target);
ICHECK(!value->cached_func.defined());
- auto cfunc = CreateSchedule(key->source_func, key->target);
- auto cache_node = make_object<CachedFuncNode>(*(cfunc.operator->()));
+ auto cfunc = PrimFuncFor(key->source_func, key->target,
+ [&](std::string name) { return
GetUniqueName(name, name_map_); });
// Skip lowering for device copy node.
const Expr body = (key->source_func)->body;
if (const CallNode* call_node = body.as<CallNode>()) {
if (call_node->attrs.as<DeviceCopyAttrs>()) {
- value->cached_func = CachedFunc(cache_node);
+ value->cached_func = cfunc;
return value;
}
}
- cache_node->func_name = GetUniqueName(cache_node->func_name);
// NOTE: array will copy on write.
- Array<te::Tensor> all_args = cache_node->inputs;
- for (te::Tensor arg : cache_node->outputs) {
+ Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+ for (te::Tensor arg : cfunc->outputs) {
all_args.push_back(arg);
}
- // lower the function
- if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
- cache_node->funcs = (*f)(cfunc->schedule, all_args,
cache_node->func_name, key->source_func);
- } else {
- using tvm::transform::PassContext;
- With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
- std::unordered_map<te::Tensor, tir::Buffer> binds;
- cache_node->funcs = tvm::lower(cfunc->schedule, all_args,
cache_node->func_name, binds);
- }
- value->cached_func = CachedFunc(cache_node);
+ using tvm::transform::PassContext;
+ With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+ std::unordered_map<te::Tensor, tir::Buffer> binds;
+ auto func_name = cfunc->prim_fn_var->name_hint;
+ cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name,
binds));
Review comment:
We need to fix this; in my opinion, this is BAD. Should we file a follow-up issue?
I think we should remove the Python API and force lowering to happen in C++.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org