mbs-octoml commented on a change in pull request #10352:
URL: https://github.com/apache/tvm/pull/10352#discussion_r813204830



##########
File path: src/printer/relay_text_printer.cc
##########
@@ -220,9 +220,13 @@ Doc RelayTextPrinter::AllocVar(const Var& var) {
   }
   Doc val = GetUniqueName("%" + name);
   memo_[var] = val;
+  if (!var->virtual_device()->IsFullyUnconstrained()) {

Review comment:
       Is there any issue with this being used for both param- and let-bound vars, even 
though we don't parse annotations for let-bound vars?

##########
File path: src/relay/ir/expr_functor.cc
##########
@@ -476,41 +476,31 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& 
args_map) {
   if (const FunctionNode* func = expr.as<FunctionNode>()) {
     Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
     Array<Var> new_params;
-    std::vector<VirtualDevice> new_param_virtual_devices;
+    bool params_unchanged = true;
     for (size_t i = 0; i < func->params.size(); ++i) {
       if (!args_map.count(func->params[i])) {
         new_params.push_back(func->params[i]);
-        
new_param_virtual_devices.push_back(GetFunctionParamVirtualDevice(func, i));
+      } else if (const auto var = args_map[func->params[i]].as<VarNode>()) {
+        // If we're mapping a variable to a variable and not a normal expr, 
then we want to
+        // put the substitution in the new parameters.
+        params_unchanged = false;
+        new_params.push_back(GetRef<Var>(var));
       }
     }
-    if (new_body.same_as(func->body) && new_params.size() == 
func->params.size()) {
+    if (new_body.same_as(func->body) && new_params.size() == 
func->params.size() &&
+        params_unchanged) {
       return expr;
     }
     auto ret =
         Function(new_params, new_body, func->ret_type, func->type_params, 
func->attrs, func->span);
-    ret =
-        MaybeFunctionOnDevice(ret, new_param_virtual_devices, 
GetFunctionResultVirtualDevice(func));
-    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> set;

Review comment:
       Bind should really be called SubstAndBind, since it assumes any free 
vars introduced in the right-hand side of the substitution should be implicitly 
bound in the result function. I think you'll need to keep these semantics.
   
   For your purposes, I think using ExprBinder directly is what you want, 
since you only need the substitution aspect. I forgot about this hackery 
when I pointed you to this function.

##########
File path: src/printer/relay_text_printer.cc
##########
@@ -329,7 +333,10 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool 
meta, bool try_inline, bo
 
 // Should only be triggered when op is a free variable being visited for the
 // first time.
-Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return 
AllocVar(GetRef<Var>(op)); }
+Doc RelayTextPrinter::VisitExpr_(const VarNode* op) {

Review comment:
       nit: I guess you can revert this now.

##########
File path: src/parser/parser.cc
##########
@@ -1107,11 +1111,26 @@ class Parser {
           [&]() {
             auto token = Match(TokenType::kLocal);
             auto string = token.ToString();
+
+            // The fake attributes where the virtual device is specified.
+            VirtualDevice virtual_device;
+            if (WhenMatch(TokenType::kLCurly)) {
+              Map<String, ObjectRef> fake_attrs = ParseAttrs();
+              VLOG(1) << "Fake attributes for function parameter: " << 
fake_attrs;

Review comment:
       meganit: I've been using VLOG(9) for these super-duper verbose messages, 
since I often just set TVM_LOG_DEBUG="DEFAULT=1" to get an overall debug trace 
that's still vaguely readable.

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -252,21 +252,16 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
       // Do that flattening on-the-fly here.
       Function inner_func = Downcast<Function>(func->body);
       std::vector<Var> params;
-      std::vector<VirtualDevice> param_virtual_devices;
       params.reserve(func->params.size() + inner_func->params.size());
-      param_virtual_devices.reserve(func->params.size() + 
inner_func->params.size());
       param_device_indexes.reserve(func->params.size() + 
inner_func->params.size());
       for (size_t i = 0; i < func->params.size(); ++i) {
         params.emplace_back(func->params[i]);
-        VirtualDevice param_virtual_device = 
GetFunctionParamVirtualDevice(func.get(), i);
-        param_virtual_devices.push_back(param_virtual_device);
-        param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));
+        
param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));

Review comment:
       we get some payoff at last!

##########
File path: src/relay/backend/vm/lambda_lift.cc
##########
@@ -111,22 +111,21 @@ class LambdaLifter : public 
transform::DeviceAwareExprMutator {
     auto free_type_vars = FreeTypeVars(func, module_);
 
     Array<Var> captured_vars;
-    std::vector<VirtualDevice> captured_var_virtual_devices;
     bool recursive = false;
     for (const auto& var : free_vars) {
       if (!letrec_.empty() && var == letrec_.back()) {
         recursive = true;
         continue;
       }
       captured_vars.push_back(var);
-      captured_var_virtual_devices.push_back(GetVirtualDevice(var));

Review comment:
       These vars may be let-bound as well as param-bound, so I don't think 
this is sound.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to