manupa-arm commented on a change in pull request #5616:
URL: https://github.com/apache/incubator-tvm/pull/5616#discussion_r427156985



##########
File path: src/relay/transforms/partition_graph.cc
##########
@@ -404,18 +404,85 @@ IRModule RemoveDefaultAnnotations(IRModule module) {
   return module;
 }
 
+/*! \brief There can be regions with multiple outputs where each output
+ *  could be a tuple output. Such tuple outputs needs to be flattened
+ *  otherwise the function would create tuples of tuples.
+ */
+
+// New annotations would be required to be added for each flattened output
+const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+
+IRModule FlattenTupleOutputs(IRModule module) {
+  class TupleOutFlattener : public ExprRewriter {
+   public:
+    TupleOutFlattener() = default;
+
+    Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
+      Expr new_op = (*ann_op)(expr, target);
+      new_op->checked_type_ = expr->checked_type_;
+      return new_op;
+    }
+
+    Expr Rewrite_(const CallNode* call, const Expr& post) final {
+      if (call->op == compiler_end_op) {
+        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+        // Arguments of annotation ops should be 1
+        CHECK_EQ(call->args.size(), 1U);
+        auto annotated_op = Downcast<Call>(post)->args[0];
+        if (annotated_op->IsInstance<TupleNode>()) {
+          auto tn = annotated_op.as<TupleNode>();
+          Array<Expr> new_fields;
+
+          // Here each input of the tuple will be annotated with compiler_ends
+          for (auto& tn_arg : tn->fields) {
+            auto nf = InsertAnnotation(tn_arg, target, make_end_op);
+            new_fields.push_back(nf);
+          }
+
+          // Return a tuple of compiler_ends in the place of the tuple that was
+          // annotated with a compiler_end.
+          auto out = Tuple(new_fields);
+          return std::move(out);
+        }
+      }
+      return post;
+    }
+  };
+
+  auto glob_funcs = module->functions;
+  // module is mutable, hence, we make a copy of it.
+  module.CopyOnWrite();
+  for (const auto& pair : glob_funcs) {
+    if (auto* fn = pair.second.as<FunctionNode>()) {
+      auto func = GetRef<Function>(fn);
+      TupleOutFlattener to_flattener;
+      auto removed = PostOrderRewrite(func->body, &to_flattener);
+      func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs);
+      module->Update(pair.first, func);
+    }
+  }
+  return module;
+}
+
 }  // namespace partitioning
 
 namespace transform {
 
 Pass PartitionGraph() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m,
                                                                             PassContext pc) {
+    // There could be compiler_end annotations on tuples
+    // If the corresponding region is having multiple compiler_ends,
+    // this would lead to creation of tuples of tuples.
+    // Thus, we flatten the tuples by transfering the compiler_end to
+    // the tuple inputs.
+    auto _m = partitioning::FlattenTupleOutputs(m);

Review comment:
   > Thanks for bringing up this issue as well as the PR. IIUC, this new pass kicks the output tuple out of the region to make sure we will not have a nested tuple output.
   > 
   > Before:
   > 
   > ```
   > a   b   c
   > |   |   |
   > --tuple--
   >     |
   >    end
   > ```
   > 
   > After:
   > 
   > ```
   > a       b      c
   > |       |      |
   > end    end    end
   > ------tuple------
   > ```
   > 
   > In terms of the implementation, I am concerned about whether we want to introduce another pass to flatten tuples instead of implementing this behavior in one of the existing passes. For example,
   > 
   >     1. In `MergeCompilerRegion`, when you see multiple outputs, apply this logic to kick the tuple out of the region.
   > 
   >     2. In `PartitionGraph`, when you see multiple outputs, traverse all outputs, make a flattened tuple, and create the corresponding `TupleGetItem` nodes.
   > 
   > 
   > I personally prefer the second approach because we cannot guarantee users will run `MergeCompilerRegion`.
   
   @comaniac thanks for explaining what I have done (graphically!). I also agree that this logic should be in the partition graph pass rather than the MergeCompilerRegions pass. However, I do not think we should tightly embed the flattening logic inside the function creation, for the following reasons:
   1) The code would be unnecessarily complex: we would need to investigate all dataflows that could possibly feed a tuple inside an annotated region set that has multiple outputs, because that tuple has to be created again after the boundary of the partitioned function. Functionality-wise, this sounds separate from "partitioning" functions out of the graph. IMO, any transformation that has to be done both within and outside the boundary of an annotated region set is going to be bulky (in terms of code, as we have to handle all possible consumers).
   2) In the future, if we decide to support nested tuples in the runtime passes/components, we might want to remove this internal pass. (That is a stretch, I know! :) )
   3) IMHO, the clarity of the code that handles the outputs of the partitioned function would be adversely affected if we glued this logic in there.
   
   Therefore, I believe it's cleaner for now to keep this as a separate pass that is internal to PartitionGraph. In the future, if the need arises, we can fuse it tightly into the PartitionGraph pass (though I cannot think of how that would be more beneficial right now).
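
For illustration only, the before/after rewrite in the diagrams above can be sketched on a toy expression tree. This is a standalone sketch, not the TVM API: `Node`, `Var`, `Tuple`, `CompilerEnd`, `ToString` and this `FlattenTupleOutputs` are hypothetical stand-ins for the Relay types and the pass in the diff, and the target name is just a string. It assumes, like the `CHECK_EQ` in the diff, that an end annotation wraps exactly one argument.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression node; a hypothetical stand-in for Relay's Expr.
struct Node;
using NodePtr = std::shared_ptr<Node>;

struct Node {
  enum class Kind { kVar, kTuple, kCompilerEnd };
  Kind kind;
  std::string name;             // variable name, or target for kCompilerEnd
  std::vector<NodePtr> fields;  // tuple fields, or the single annotated argument
};

NodePtr Var(const std::string& name) {
  return std::make_shared<Node>(Node{Node::Kind::kVar, name, {}});
}
NodePtr Tuple(std::vector<NodePtr> fields) {
  return std::make_shared<Node>(Node{Node::Kind::kTuple, "", std::move(fields)});
}
NodePtr CompilerEnd(NodePtr arg, const std::string& target) {
  return std::make_shared<Node>(Node{Node::Kind::kCompilerEnd, target, {std::move(arg)}});
}

// Post-order rewrite: end(tuple(a, b, ...)) becomes tuple(end(a), end(b), ...).
NodePtr FlattenTupleOutputs(const NodePtr& expr) {
  // Rewrite the children first.
  std::vector<NodePtr> new_fields;
  for (const auto& f : expr->fields) new_fields.push_back(FlattenTupleOutputs(f));

  if (expr->kind == Node::Kind::kCompilerEnd) {
    // An end annotation is assumed to wrap exactly one argument.
    assert(new_fields.size() == 1);
    if (new_fields[0]->kind == Node::Kind::kTuple) {
      // Move the annotation from the tuple onto each of its fields.
      std::vector<NodePtr> ends;
      for (const auto& f : new_fields[0]->fields) ends.push_back(CompilerEnd(f, expr->name));
      return Tuple(std::move(ends));
    }
  }
  return std::make_shared<Node>(Node{expr->kind, expr->name, std::move(new_fields)});
}

// Render an expression as text so the rewrite can be inspected.
std::string ToString(const NodePtr& e) {
  if (e->kind == Node::Kind::kVar) return e->name;
  std::string out =
      (e->kind == Node::Kind::kTuple) ? std::string("tuple(") : "end[" + e->name + "](";
  for (std::size_t i = 0; i < e->fields.size(); ++i) {
    if (i > 0) out += ", ";
    out += ToString(e->fields[i]);
  }
  return out + ")";
}
```

With `CompilerEnd(Tuple({Var("a"), Var("b"), Var("c")}), "ext")` as input, `ToString` of the rewritten tree gives `tuple(end[ext](a), end[ext](b), end[ext](c))`, which matches the "After" diagram. Because the rewrite stays local to the annotation node, it can live in a separate pass without touching the function-creation logic.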



