jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r609055222
##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphExecutorCodegen : public
backend::MemoizedExprTranslator<std::vector<
return AddNode(node, GetRef<Expr>(op));
}
- std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
- Expr expr = GetRef<Expr>(op);
- Function func;
- if (op->op.as<OpNode>()) {
- LOG(FATAL) << "Operators should be transformed away; try applying"
- << "the fuse_ops transformation to the expression.";
- } else if (op->op.as<GlobalVarNode>()) {
- LOG(FATAL) << "Not implemented";
- } else if (op->op.as<FunctionNode>()) {
- func = GetRef<Function>(op->op.as<FunctionNode>());
- } else {
- LOG(FATAL) << "TVM runtime does not support calls to " <<
op->op->GetTypeKey();
- }
- if (!func->HasNonzeroAttr(attr::kPrimitive)) {
- LOG(FATAL) << "TVM only support calls to primitive functions "
- << "(i.e functions composed of fusable operator invocations)";
- }
+ std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
+ relay::Call call = GetRef<Call>(call_node);
+ if (auto global_node = call->op.as<GlobalVarNode>()) {
- auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
- auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
- Target target;
- // Handle external function
- if (func->GetAttr<String>(attr::kCompiler).defined()) {
- target = Target("ext_dev");
- CCacheKey key = (*pf0)(func, target);
- CachedFunc ext_func = (*pf1)(compile_engine_, key);
- ICHECK(ext_func.defined()) << "External function is not defined.";
- UpdateConstants(func, ¶ms_);
- return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
- }
+ auto prim_fn_name = global_node->name_hint;
- ICHECK_GE(storage_device_map_.count(expr), 0);
- auto& device_type = storage_device_map_[expr][1];
- auto call_dev_type = device_type[0]->value;
- // Normal Relay Function
- if (targets_.size() == 1) {
- // homogeneous execution.
- const auto& it = targets_.begin();
- target = (*it).second;
- } else {
- // heterogeneous execution.
- std::string call_dev_name;
- if (call_dev_type == 0) {
- call_dev_name = "llvm";
+ Target target;
+
+ // // Handle external function
+ // if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ // UpdateConstants(func, ¶ms_);
+ // return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name);
Review comment:
This code has been moved inside the lowering, so the case below should now
catch it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]