mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759565993



##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -540,65 +505,60 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
     }
 
     for (auto output : output_tuple->fields) {
+      ICHECK(output->IsInstance<VarNode>()) << "output should be var, found:" 
<< std::endl
+                                            << PrettyPrint(output);
       auto reg = var_register_map_.find(Downcast<Var>(output));
       ICHECK(reg != var_register_map_.end())
           << "internal error: all variables should be in the register mapping";
       argument_registers.push_back(reg->second);
     }
 
-    Target target;
-
-    // Which target should execute the function?
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      target = Target("ext_dev");
+    Index op_index;
+    auto itr = context_->primitive_map.find(global_var_node->name_hint);
+    if (itr == context_->primitive_map.end()) {
+      op_index = context_->primitive_map.size();
+      context_->primitive_map.emplace(global_var_node->name_hint, op_index);
     } else {
-      target = se_scope->target;
+      op_index = itr->second;
     }
-    ICHECK(target.defined()) << "No target for function:" << std::endl << 
PrettyPrint(func);
-
-    tec::CCacheKey key(func, target);
-    auto mangle_fn = [](String name) { return name; };
-    auto cfunc = context_->compiler->Lower(key, mangle_fn);  // <<<< 
one-func-at-a-time lowering
 
-    auto op_index = -1;
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      op_index = context_->cached_funcs.size();
-      context_->cached_funcs.push_back(cfunc);
-    } else {
-      // TODO(jroesch): support lowered funcs for multiple targets
-      ICHECK_EQ(cfunc->funcs->functions.size(), 1);
-      auto pfunc = 
Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
-      if (context_->seen_funcs.find(pfunc) == context_->seen_funcs.end()) {
-        op_index = context_->cached_funcs.size();
-        context_->cached_funcs.push_back(cfunc);
-        context_->seen_funcs[pfunc] = op_index;
-      } else {
-        op_index = context_->seen_funcs[pfunc];
-      }
-    }
-
-    // Extract functions attrs
-    op_attrs[op_index] = func->attrs->dict;
+    // Capture the dictionary of attributes from the original primitive 
function so that they

Review comment:
       I think that's a separate issue, related to the versioning of kernels etc. This change preserves the existing behaviour.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to