jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758751452



##########
File path: src/relay/transforms/memory_alloc.cc
##########
@@ -311,48 +292,58 @@ class DialectRewriter : public 
transform::DeviceAwareExprMutator {
           shape_func_ins.push_back(scope->Push(in_shape_var, sh_of));
           input_pos++;
         }
-        is_inputs.push_back(0);
       } else if (state == tec::kNeedInputData) {
         auto new_arg = Mutate(arg);  // already accounts for device
         SEScope arg_se_scope = GetSEScope(arg);
+        // The dynamic shape function is expecting its data on the host/CPU, 
so insert a
+        // device_copy otherwise. (We'll need to fuse & lower these copies in 
the same way
+        // we fuse & lower other operators we insert for, eg, dynamic tensor 
size calculation.)
         if (arg_se_scope != host_se_scope_) {
           new_arg = OnDevice(DeviceCopy(new_arg, arg_se_scope, 
host_se_scope_), host_se_scope_,
                              /*is_fixed=*/true);
         }
         Var in_shape_var("in_shape_" + std::to_string(input_pos), 
Type(nullptr));
         shape_func_ins.push_back(scope->Push(in_shape_var, new_arg));
         input_pos++;
-        is_inputs.push_back(1);
       } else {
         // TODO(@jroesch): handle kNeedBoth
         LOG(FATAL) << "unsupported shape function input state";
       }
     }
+    ICHECK_EQ(shape_func_ins.size(), func_type_node->arg_types.size());
+
+    // Establish the result shapes.
+    const auto* res_tuple_node = func_type_node->ret_type.as<TupleTypeNode>();
+    ICHECK(res_tuple_node);
 
     Array<Expr> out_shapes;
-    for (size_t i = 0; i < cfunc->outputs.size(); ++i) {
-      auto out = cfunc->outputs[i];
-      auto tt = TensorType(out->shape, out->dtype);
-      // Put shape func on CPU. This also ensures that everything between
-      // shape_of and shape_func are on CPU.
-      auto alloc = OnDevice(MakeStaticAllocation(scope, tt, host_se_scope_, 
std::to_string(i)),
-                            host_se_scope_, /*is_fixed=*/true);
+    for (size_t i = 0; i < res_tuple_node->fields.size(); ++i) {
+      const auto* tensor_type_node = 
res_tuple_node->fields[i].as<TensorTypeNode>();
+      ICHECK(tensor_type_node);
+      // Put the shape func on the host. This also ensures that everything 
between
+      // shape_of and shape_func is similarly on the host.
+      Expr alloc = MakeStaticAllocation(scope, 
GetRef<TensorType>(tensor_type_node), host_se_scope_,
+                                        std::to_string(i));
+      // TODO(mbs): Why extra var binding?

Review comment:
       I think the original pass was incrementally re-ANF-ing the code (i.e. re-establishing A-normal form as it rewrote), hence the extra var binding.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to