zhiics commented on a change in pull request #5143: [RELAY] Re-wrote the Graph Partitioner to support multiple outputs
URL: https://github.com/apache/incubator-tvm/pull/5143#discussion_r397979319
 
 

 ##########
 File path: src/relay/transforms/partition_graph.cc
 ##########
 @@ -165,101 +162,142 @@ class Partitioner : public ExprMutator {
       // Traverse the rest graph.
       auto input_expr = VisitExpr(call->args[0]);
 
-      // Replace the begin annotation with an external call input variable.
-      auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+      AnnotatedRegion sg = GetRegion(GetRef<Call>(call));
+      int index = GetArgIdx(sg, GetRef<Call>(call));
+      CHECK_NE(index, -1);
       // The type of the created variable is the same as the compiler_begin
       // node.
-      auto var = Var(compiler_attrs->compiler + "_input" + 
std::to_string(var_id_++),
-                               call->checked_type_);
-
-      // Find the corresponding subgraph and add the argument.
-      auto subgraph = GetSubgraph(GetRef<Call>(call));
-      if (!subgraph) {
-        throw Error(ErrorBuilder()
-                    << "Cannot find the corresponding subgraph for start 
annotation:\n"
-                    << AsText(GetRef<Call>(call), false));
-      }
-      subgraph->args.push_back({var, input_expr});
+      std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+      std::string varname = target + "_" + std::to_string(sg->GetID())
+                            + "_i" + std::to_string(index);
+      auto var = Var(varname, GetRef<Call>(call)->checked_type_);
+
+      auto cand = std::make_pair(var, input_expr);
+      if (std::find(region_args[sg].begin(),
+                    region_args[sg].end(), cand) == region_args[sg].end()) {
+        region_args[sg].push_back(cand);
+     }
+
       return std::move(var);
     } else {
       CHECK_EQ(call->op, compiler_end_op);
       // The annotation node is inserted on edge so it must have only one 
argument.
       CHECK_EQ(call->args.size(), 1U);
 
-      auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+      AnnotatedRegion region = GetRegion(GetRef<Call>(call));
 
-      // Check if the argument already belongs to an existing subgraph
-      auto subgraph = GetSubgraph(call->args[0]);
-      if (!subgraph) {
-        auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
-        subgraph = *ret.first;
-        subgraph->nodes.insert(call->args[0]);
-        subgraph->id = this->subgraph_id_++;
-      }
-      subgraph->nodes.insert(GetRef<Call>(call));
+      // TODO(@manupa-arm) : need to use the parent function (to which region
+      // belongs to) name/key for the funtions that are created
+      BaseFunc f = GetFunc(GetRef<Call>(call));
 
       // Traverse subgraph inputs.
       auto input = VisitExpr(call->args[0]);
-      Array<Var> params;
-      Array<Expr> args;
-      std::unordered_map<std::string, runtime::NDArray> params_bind;
-
-      // The subgraph may be merged so we need to update it again.
-      subgraph = GetSubgraph(GetRef<Call>(call));
-      CHECK(subgraph);
-
-      // Record the constants for propagation.
-      for (auto pair : subgraph->args) {
-        params.push_back(pair.first);
-        if (const auto* cn = pair.second.as<ConstantNode>()) {
-          params_bind[pair.first->name_hint()] = cn->data;
+      CHECK(region.defined()) << "Region not defined for " << 
GetRef<Call>(call);
+      // functions are created for each annotated regions,
+      // when their first output is encountered.
+      // If multiple outputs are there, a tuple node is inserted at the end.
+      // region_function_calls is map that maintains
+      // (each annotated regions) --> created function
+
+      if (region_function_calls.find(region) != region_function_calls.end()) {
+      // This section is executed only if there are multiple outputs in the 
region
+      // Thus, the function is always created and at the end there would be a 
tuple node
+      // Therefore, we insert a tuple get item node.
+
+      // Use the already created tuple node
+        auto sg_call = region_function_calls[region];
+        int index = GetRetIdx(region, GetRef<Call>(call));
+        CHECK_NE(index, -1);
+
+        auto tuple_get_item_ = TupleGetItem(sg_call, index);
+        tuple_get_item_->checked_type_ = 
GetRef<Call>(call)->args[0]->checked_type_;
+        return std::move(tuple_get_item_);
+      } else {
+        // First time this region is encountered in the traversal
+        // Creating the function
+
+        Array<Expr> fields;
+
+        for (auto ret : region->GetOutputs()) {
+          auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
+          fields.push_back(ret_expr);
+        }
+        int index = GetRetIdx(region, GetRef<Call>(call));
+        CHECK_NE(index, -1);
+
+        Array<Var> params;
+        Array<Expr> param_expr;
+        std::unordered_map<std::string, runtime::NDArray> params_bind;
+
+        for (auto pair : region_args[region]) {
+          params.push_back(pair.first);
+          if (const auto* cn = pair.second.as<ConstantNode>()) {
+            params_bind[pair.first->name_hint()] = cn->data;
+          } else {
+            param_expr.push_back(pair.second);
+          }
+        }
+
+        Function global_region_func;
+        if (region->GetOutputs().size() == 1) {
+          // If there are only a single output; no need to add a tuple
+          global_region_func = Function(params, fields[0],
+                                  call->args[0]->checked_type_, {}, 
DictAttrs());
         } else {
-          args.push_back(pair.second);
+          auto tuple = Tuple(fields);
+          global_region_func = Function(params, tuple, tuple->checked_type_, 
{}, DictAttrs());
+        }
+
+        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+        std::string name = target + "_" + std::to_string(region->GetID());
+
+        global_region_func = WithAttr(std::move(global_region_func), 
attr::kExternalSymbol,
+                                      tir::StringImmNode::make(name));
+        global_region_func = WithAttr(std::move(global_region_func), 
attr::kPrimitive,
+                            tvm::Integer(1));
+        global_region_func = WithAttr(std::move(global_region_func), 
attr::kCompiler,
+                                      tvm::tir::StringImmNode::make(target));
+        global_region_func = WithAttr(std::move(global_region_func), 
attr::kInline,
+                            tvm::Integer(1));
+
+        // Constant propagation
+        if (!params_bind.empty()) {
+          global_region_func = backend::BindParamsByName(global_region_func, 
params_bind);
         }
-      }
 
-      auto subgraph_func =
-          Function(params, input, call->checked_type_, {});
-
-      std::string name = compiler_attrs->compiler + "_" + 
std::to_string(subgraph->id);
-      subgraph_func =
-          WithAttr(std::move(subgraph_func), attr::kExternalSymbol, 
tir::StringImmNode::make(name));
-      subgraph_func =
-          WithAttr(std::move(subgraph_func), attr::kPrimitive, 
tvm::Integer(1));
-      subgraph_func =
-          WithAttr(std::move(subgraph_func), attr::kCompiler,
-                   tvm::tir::StringImmNode::make(compiler_attrs->compiler));
-      subgraph_func =
-          WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1));
-
-      // Constant propagation
-      if (!params_bind.empty()) {
-        subgraph_func = backend::BindParamsByName(subgraph_func, params_bind);
+        std::string fname = name;
+        CHECK(!module_->ContainGlobalVar(fname))
+                << "Global function " << fname << " already exists";
+        // Create a global function and add it to the IRModule for the region.
+        // This way we lift the functions that should be handled by external
+        // codegen to the module scope and rely on the pass manager to prevent 
relay
+        // function level passes (i.e. simplify inference and fusion) 
optimizing it.
+        GlobalVar glob_func(fname);
+        module_->Add(glob_func, global_region_func);
+
+        // The return type of callnode is the same as the type of the
+        // compiler_end node.
+        auto ret = Call(glob_func, param_expr);
+        region_function_calls[region] = ret;
+
+        if (region->GetOutputs().size() == 1) {
+          // If there is only a single output; no need to add a tuplegetitem 
node
+          return Call(glob_func, param_expr);
 
 Review comment:
  No need to create another call here; `ret` already holds `Call(glob_func, param_expr)`, so just `return ret`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to