comaniac commented on a change in pull request #5493:
URL: https://github.com/apache/incubator-tvm/pull/5493#discussion_r419586113



##########
File path: src/relay/transforms/partition_graph.cc
##########
@@ -124,37 +112,40 @@ class AnnotationChecker : public ExprVisitor {
  *         the compiler name.
  */
 
-class Partitioner : public ExprMutator {
+class Partitioner : public MixedModeMutator {
  public:
   explicit Partitioner(const IRModule& module) : module_(module) {
     for (auto f : module->functions) {
       GlobalVar f_var = f.first;
       BaseFunc f_func = f.second;
 
-      // Creating regionset per function in the module
+      // Creating regionset per function in the module.
       auto region_set = AnnotatedRegionSet::Create(f_func, 
partitioning::compiler_begin_op,
                                                    
partitioning::compiler_end_op);
       regions_sets_[region_set] = f_func;
+
+      // Initial region function metadata.
+      for (auto region : region_set) {
+        region_func_meta_[region];
+      }
     }
   }
 
-  Expr VisitExpr_(const CallNode* call) final {
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
     auto op_node = call->op.as<OpNode>();
     if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
-      return ExprMutator::VisitExpr_(call);
+      return post;
     } else if (call->op == compiler_begin_op) {
-      // The annotation node is inserted on edge so it must have only one
-      // argument.
+      // The annotation node is inserted on edge so it must have only one 
argument.
       CHECK_EQ(call->args.size(), 1U);
 
       // Traverse the rest graph.
       Expr parent = call->args[0];
-      auto input_expr = VisitExpr(parent);
+      auto input_expr = Downcast<Call>(post)->args[0];
 
       // Backtrace the parent to find the first ancestor node that is not a 
begin or end op
       while (const auto* parent_call = parent.as<CallNode>()) {
-        if (parent_call->op == compiler_begin_op ||
-            parent_call->op == compiler_end_op) {
+        if (parent_call->op == compiler_begin_op || parent_call->op == 
compiler_end_op) {

Review comment:
      It's possible that this would work as you suggest, but we haven't tried 
it. Our reasons are: 1) there seems to be no benefit to using `post` instead of 
`pre`, and 2) we deliberately translated the original logic as-is when 
refactoring the pass into a non-recursive form, so that we could be sure it 
still works.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to