jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758750942



##########
File path: src/relay/transforms/memory_alloc.cc
##########
@@ -122,64 +107,81 @@ class DialectRewriter : public 
transform::DeviceAwareExprMutator {
     return ret;
   }
 
-  Expr DeviceAwareVisitExpr_(const CallNode* cn) final {
-    Call call = GetRef<Call>(cn);
+  Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
+    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
+    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+
+    if (device_copy_props.body.defined()) {
+      // Special case: device_copy calls remain in their original (and 
functional) form.
+      // TODO(mbs): device_copy cleanup.
+      return 
transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
+    }
+
+    if (!call_lowered_props.lowered_func.defined()) {
+      // This is a call to a user-defined Relay function, which will be 
handled directly by
+      // the VM and does not need conversion to DPS.
+      return 
transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
+    }
+
+    Call call = GetRef<Call>(call_node);
+    VLOG(1) << "converting lowered call to DPS:" << std::endl << 
PrettyPrint(call);
+
     SEScope se_scope = GetSEScope(call);
-    if (IsPrimitive(cn)) {
-      // Because we are in ANF we do not need to visit the arguments.
-      // TODO(mbs): But does so anyway...
-      LetList& scope = scopes_.back();
-      std::vector<Expr> new_args;
-      for (const auto& it : cn->args) {
-        new_args.push_back(Mutate(it));
-      }
+    LetList& scope = scopes_.back();
 
-      Tuple ins(new_args);
-      Type ret_type = cn->checked_type_;
-      std::vector<TensorType> out_types = FlattenTupleType(ret_type);
+    std::vector<Expr> new_args;
+    for (const auto& arg : call_lowered_props.arguments) {
+      new_args.push_back(Mutate(arg));
+    }
+    Tuple ins(new_args);
+    Type ret_type = call_node->checked_type_;
+    std::vector<TensorType> out_types = FlattenTupleType(ret_type);
+
+    // Handle reshape.

Review comment:
       Nevermind I see now.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to