trevor-m commented on a change in pull request #5320: [BYOC] Prevent duplicate outputs in subgraph Tuple
URL: https://github.com/apache/incubator-tvm/pull/5320#discussion_r408368549
##########
File path: src/relay/transforms/partition_graph.cc
##########
@@ -206,97 +206,16 @@ class Partitioner : public ExprMutator {
// (each annotated regions) --> created function
if (region_function_calls.find(region) != region_function_calls.end()) {
- // This section is executed only if there are multiple outputs in the
- // region Thus, the function is always created and at the end there
- // would be a tuple node Therefore, we insert a tuple get item node.
-
- // Use the already created tuple node
- auto sg_call = region_function_calls[region];
- int index = GetRetIdx(region, GetRef<Call>(call));
- CHECK_NE(index, -1);
-
- auto tuple_get_item_ = TupleGetItem(sg_call, index);
- tuple_get_item_->checked_type_ =
GetRef<Call>(call)->args[0]->checked_type_;
- return std::move(tuple_get_item_);
+ // This section is executed if there are multiple outputs in the region
+ // or if the output of the function is being accessed multiple times by
+ // different nodes.
+ return GetFunctionOutput(region, GetRef<Call>(call));
} else {
- // First time this region is encountered in the traversal
- // Creating the function
-
- Array<Expr> fields;
-
- for (auto ret : region->GetOutputs()) {
- auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
- fields.push_back(ret_expr);
- }
- int index = GetRetIdx(region, GetRef<Call>(call));
- CHECK_NE(index, -1);
-
- Array<Var> params;
- Array<Expr> param_expr;
- std::unordered_map<std::string, runtime::NDArray> params_bind;
-
- for (auto pair : region_args[region]) {
- params.push_back(pair.first);
- if (const auto* cn = pair.second.as<ConstantNode>()) {
- params_bind[pair.first->name_hint()] = cn->data;
- } else {
- param_expr.push_back(pair.second);
- }
- }
-
- Function global_region_func;
- if (region->GetOutputs().size() == 1) {
- // If there are only a single output; no need to add a tuple
- global_region_func =
- Function(params, fields[0], call->args[0]->checked_type_, {},
DictAttrs());
- } else {
- auto tuple = Tuple(fields);
- global_region_func = Function(params, tuple, tuple->checked_type_,
{}, DictAttrs());
- }
-
- std::string target = call->attrs.as<CompilerAttrs>()->compiler;
- std::string name = target + "_" + std::to_string(region->GetID());
-
- global_region_func = WithAttr(std::move(global_region_func),
tvm::attr::kGlobalSymbol,
- runtime::String(name));
- global_region_func =
- WithAttr(std::move(global_region_func), attr::kPrimitive,
tvm::Integer(1));
- global_region_func = WithAttr(std::move(global_region_func),
attr::kCompiler,
- tvm::runtime::String(target));
- global_region_func =
- WithAttr(std::move(global_region_func), attr::kInline,
tvm::Integer(1));
-
- // Constant propagation
- if (!params_bind.empty()) {
- global_region_func = backend::BindParamsByName(global_region_func,
params_bind);
- }
-
- std::string fname = name;
- CHECK(!module_->ContainGlobalVar(fname))
- << "Global function " << fname << " already exists";
- // Create a global function and add it to the IRModule for the region.
- // This way we lift the functions that should be handled by external
- // codegen to the module scope and rely on the pass manager to prevent
- // relay function level passes (i.e. simplify inference and fusion)
- // optimizing it.
- GlobalVar glob_func(fname);
- module_->Add(glob_func, global_region_func);
-
- // The return type of callnode is the same as the type of the
- // compiler_end node.
- auto ret = Call(glob_func, param_expr);
- region_function_calls[region] = ret;
-
- if (region->GetOutputs().size() == 1) {
- // If there is only a single output; no need to add a tuplegetitem
- // node
- return std::move(ret);
- } else {
- // Add a tuplegetitem node to select this output out of many
- auto tuple_get_item_ = TupleGetItem(ret, index);
- tuple_get_item_->checked_type_ =
GetRef<Call>(call)->args[0]->checked_type_;
- return std::move(tuple_get_item_);
- }
+ // First time this region is encountered in the traversal.
+ // Creating the function.
+ CreateFunction(region, call);
+ // Retrieve particular output.
+ return GetFunctionOutput(region, GetRef<Call>(call));
Review comment:
Thanks! Done.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services