mbs-octoml commented on a change in pull request #8788:
URL: https://github.com/apache/tvm/pull/8788#discussion_r720592924



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -518,54 +515,55 @@ class LowerTensorExprMutator : public ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const CallNode* call) override {
-    Call expr = GetRef<Call>(call);
+  Expr VisitExpr_(const CallNode* call_node) override {
+    Device device = GetOnDeviceDevice(call_node);
+    if (device.device_type) {
+      PushDevice(device);
+      Expr arg = VisitExpr(call_node->args[0]);
+      PopDevice();
+      return Call(call_node->op, {arg}, call_node->attrs, call_node->type_args, call_node->span);
+    }
+
+    Call expr = GetRef<Call>(call_node);
 
     // Look for (indirect) calls to primitives.
-    Function prim_func = ResolveToPrimitive(call->op);
+    Function prim_func = ResolveToPrimitive(call_node->op);
     if (!prim_func.defined()) {
-      // Not a call to a primitive function.
-      if (const FunctionNode* fn = call->op.as<FunctionNode>()) {
+      // Not a call_node to a primitive function.
+      if (const FunctionNode* fn = call_node->op.as<FunctionNode>()) {
         this->process_fn_(GetRef<Function>(fn));
       }
-      return ExprMutator::VisitExpr_(call);
+      return ExprMutator::VisitExpr_(call_node);
     }
 
     // Find the desired target device.
     Target target;
     if (prim_func->GetAttr<String>(attr::kCompiler).defined()) {
       // The generic 'external device' target.
       target = Target("ext_dev");
-    } else if (device_context_map_.empty() && targets_.size() == 1) {
-      // The unique target.
-      target = GetTargetFromInteger(kDLCPU, targets_);
     } else {
-      // The target corresponding to the call expression's annotation.
-      auto itr = device_context_map_.find(expr);
-      ICHECK(itr != device_context_map_.end())
-          << "Could not find an entry in the device context map for " << expr
-          << "The memory planning was either not performed for this precise node, or there is "
-             "bug in the memory planner.";
-      target = GetTargetFromInteger(itr->second.device_type, targets_);
+      // The target corresponding to the call_node expression's annotation.
+      device = GetDevice(expr);
+      // TODO(mbs): device id
+      target = GetTargetFromInteger(device.device_type, targets_);

Review comment:
       I think this one will go away with @mikepapadim's #9134




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to