mbs-octoml commented on a change in pull request #10124:
URL: https://github.com/apache/tvm/pull/10124#discussion_r796905096



##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -489,51 +489,71 @@ class DeviceAnalyzer : public ExprVisitor {
 
   void VisitExpr_(const CallNode* call_node) final {
     auto call = GetRef<Call>(call_node);
-
-    // We don't care if the call is in pre- or post-lowered form.
-    auto vanilla_call = GetAnyCall(call_node);
-
-    // Find the higher-order domain for the callee. See DomainForCallee for 
the special rules
-    // for primitives.
-    VisitExpr(vanilla_call->op);
-    auto func_domain = domains_->DomainForCallee(call);  // higher-order
-
-    // Build the domain for the function implied by its arguments and call 
context.
-    ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()) << 
PrettyPrint(call);
-    std::vector<DeviceDomainPtr> args_and_result_domains;
-    args_and_result_domains.reserve(vanilla_call->args.size() + 1);
-    for (const auto& arg : vanilla_call->args) {
-      args_and_result_domains.emplace_back(domains_->DomainFor(arg));
-      VisitExpr(arg);
-    }
-    args_and_result_domains.emplace_back(domains_->DomainFor(call));
-    auto implied_domain =
-        domains_->MakeHigherOrderDomain(std::move(args_and_result_domains));  
// higher-order
-
-    VLOG(2) << "initial call function domain:" << std::endl
-            << domains_->ToString(func_domain) << std::endl
-            << "and implied domain:" << std::endl
-            << domains_->ToString(implied_domain) << std::endl
-            << "for call:" << std::endl
-            << PrettyPrint(call);
-
-    // The above must match.
-    if (domains_->UnifyOrNull(func_domain, implied_domain) == nullptr) {  // 
higher-order
-      // TODO(mbs): Proper diagnostics.
-      LOG(FATAL)
-          << "Function parameters and result VirtualDevices do not match those 
of call. Call:"
-          << std::endl
-          << PrettyPrint(call) << std::endl
-          << "with function virtual devices:" << std::endl
-          << domains_->ToString(func_domain) << std::endl
-          << "and implied call virtual devices:" << std::endl
-          << domains_->ToString(implied_domain);
+    std::stack<Expr> stack;

Review comment:
      I wrote this as a "naive" ExprVisitor rather than a MixedModeVisitor 
because I wrongly assumed it would only ever see post-ANF graphs. That is why 
LetNodes are handled iteratively, while every other node kind recurses with no 
bound on depth. Sorry this is causing you trouble!
   
   Since it only sets up domain constraints, the visiting order shouldn't 
matter much (at least for valid Relay graphs), and there's no tricky state 
management inside the visitor. So I think it could be converted to a 
MixedModeVisitor fairly easily. Would you be comfortable trying that? I'm 
happy to review it or help out in any other way.
   
   

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -489,51 +489,71 @@ class DeviceAnalyzer : public ExprVisitor {
 
   void VisitExpr_(const CallNode* call_node) final {
     auto call = GetRef<Call>(call_node);
-
-    // We don't care if the call is in pre- or post-lowered form.
-    auto vanilla_call = GetAnyCall(call_node);
-
-    // Find the higher-order domain for the callee. See DomainForCallee for 
the special rules

Review comment:
       Can you please preserve the comments? Thanks!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to