This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 09eb508 [BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)
09eb508 is described below
commit 09eb50820b1629f77531acaad16701bf572293f0
Author: Trevor Morris <[email protected]>
AuthorDate: Wed Apr 15 13:33:31 2020 -0700
[BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)
* Fix duplicate output in partitiongraph
* Add test case
* Fix test_annotated_regions with duplicate compiler_end outputs
* Revert "Fix duplicate output in partitiongraph"
This reverts commit e1f8ef3f4ca5b2aaa31ace6fa968bb50e5e4d1fa.
* Prevent duplicate outputs in Tuple in PartitionGraph
* Fix lint
* Add another test case for when regions are merged, and when TupleGetItem
was duplicated
* Pull GetFunctionOutput out of branch, improve description of
GetFunctionOutput
* Use std::move for GetFunctionOutput. Fix typo with testcase name
* Use tvm.transform.Sequential
---
src/relay/transforms/partition_graph.cc | 226 +++++++++++++-----------
tests/python/relay/test_pass_partition_graph.py | 135 ++++++++++++++
2 files changed, 260 insertions(+), 101 deletions(-)
diff --git a/src/relay/transforms/partition_graph.cc
b/src/relay/transforms/partition_graph.cc
index c8367fb..15ad60b 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -205,99 +205,13 @@ class Partitioner : public ExprMutator {
// region_function_calls is map that maintains
// (each annotated regions) --> created function
- if (region_function_calls.find(region) != region_function_calls.end()) {
- // This section is executed only if there are multiple outputs in the
- // region Thus, the function is always created and at the end there
- // would be a tuple node Therefore, we insert a tuple get item node.
-
- // Use the already created tuple node
- auto sg_call = region_function_calls[region];
- int index = GetRetIdx(region, GetRef<Call>(call));
- CHECK_NE(index, -1);
-
- auto tuple_get_item_ = TupleGetItem(sg_call, index);
- tuple_get_item_->checked_type_ =
GetRef<Call>(call)->args[0]->checked_type_;
- return std::move(tuple_get_item_);
- } else {
- // First time this region is encountered in the traversal
- // Creating the function
-
- Array<Expr> fields;
-
- for (auto ret : region->GetOutputs()) {
- auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
- fields.push_back(ret_expr);
- }
- int index = GetRetIdx(region, GetRef<Call>(call));
- CHECK_NE(index, -1);
-
- Array<Var> params;
- Array<Expr> param_expr;
- std::unordered_map<std::string, runtime::NDArray> params_bind;
-
- for (auto pair : region_args[region]) {
- params.push_back(pair.first);
- if (const auto* cn = pair.second.as<ConstantNode>()) {
- params_bind[pair.first->name_hint()] = cn->data;
- } else {
- param_expr.push_back(pair.second);
- }
- }
-
- Function global_region_func;
- if (region->GetOutputs().size() == 1) {
- // If there are only a single output; no need to add a tuple
- global_region_func =
- Function(params, fields[0], call->args[0]->checked_type_, {},
DictAttrs());
- } else {
- auto tuple = Tuple(fields);
- global_region_func = Function(params, tuple, tuple->checked_type_,
{}, DictAttrs());
- }
-
- std::string target = call->attrs.as<CompilerAttrs>()->compiler;
- std::string name = target + "_" + std::to_string(region->GetID());
-
- global_region_func = WithAttr(std::move(global_region_func),
tvm::attr::kGlobalSymbol,
- runtime::String(name));
- global_region_func =
- WithAttr(std::move(global_region_func), attr::kPrimitive,
tvm::Integer(1));
- global_region_func = WithAttr(std::move(global_region_func),
attr::kCompiler,
- tvm::runtime::String(target));
- global_region_func =
- WithAttr(std::move(global_region_func), attr::kInline,
tvm::Integer(1));
-
- // Constant propagation
- if (!params_bind.empty()) {
- global_region_func = backend::BindParamsByName(global_region_func,
params_bind);
- }
-
- std::string fname = name;
- CHECK(!module_->ContainGlobalVar(fname))
- << "Global function " << fname << " already exists";
- // Create a global function and add it to the IRModule for the region.
- // This way we lift the functions that should be handled by external
- // codegen to the module scope and rely on the pass manager to prevent
- // relay function level passes (i.e. simplify inference and fusion)
- // optimizing it.
- GlobalVar glob_func(fname);
- module_->Add(glob_func, global_region_func);
-
- // The return type of callnode is the same as the type of the
- // compiler_end node.
- auto ret = Call(glob_func, param_expr);
- region_function_calls[region] = ret;
-
- if (region->GetOutputs().size() == 1) {
- // If there is only a single output; no need to add a tuplegetitem
- // node
- return std::move(ret);
- } else {
- // Add a tuplegetitem node to select this output out of many
- auto tuple_get_item_ = TupleGetItem(ret, index);
- tuple_get_item_->checked_type_ =
GetRef<Call>(call)->args[0]->checked_type_;
- return std::move(tuple_get_item_);
- }
+ if (region_function_calls.find(region) == region_function_calls.end()) {
+ // First time this region is encountered in the traversal.
+ // Creating the function.
+ CreateFunction(region, call);
}
+ // Retrieve this particular output of function.
+ return GetFunctionOutput(region, GetRef<Call>(call));
}
}
@@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
}
/*!
- * \brief Get the index of the return(output);
- * this is to be used as tuplegetitem idx
+ * \brief This function is called the first time that we encounter a compiler_end
+ * node to create the function for the subgraph.
*/
- int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
- int idx = 0;
- for (auto arg_ : sg->GetOutputs()) {
- if (arg == arg_) {
- return idx;
+ void CreateFunction(AnnotatedRegion region, const CallNode* call) {
+ // Create fields which is a unique list of outputs. Also populate
+ // region_return_indices_ map which maps parent of compiler_end node to
+ // corresponding index in fields.
+ Array<Expr> fields;
+ int i = 0;
+ for (auto ret : region->GetOutputs()) {
+ auto ret_node = Downcast<Call>(ret)->args[0];
+ // Don't duplicate outputs.
+ if (!region_return_indices_.count(region) ||
+ !region_return_indices_[region].count(ret_node)) {
+ auto ret_expr = VisitExpr(ret_node);
+ fields.push_back(ret_expr);
+ region_return_indices_[region][ret_node] = i;
+ i++;
}
- idx++;
}
- return -1;
+
+ Array<Var> params;
+ Array<Expr> param_expr;
+ std::unordered_map<std::string, runtime::NDArray> params_bind;
+
+ for (auto pair : region_args[region]) {
+ params.push_back(pair.first);
+ if (const auto* cn = pair.second.as<ConstantNode>()) {
+ params_bind[pair.first->name_hint()] = cn->data;
+ } else {
+ param_expr.push_back(pair.second);
+ }
+ }
+
+ Function global_region_func;
+ if (fields.size() == 1) {
+ // If there is only a single output; no need to add a tuple
+ global_region_func =
+ Function(params, fields[0], call->args[0]->checked_type_, {},
DictAttrs());
+ } else {
+ auto tuple = Tuple(fields);
+ global_region_func = Function(params, tuple, tuple->checked_type_, {},
DictAttrs());
+ }
+
+ std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+ std::string name = target + "_" + std::to_string(region->GetID());
+
+ global_region_func = WithAttr(std::move(global_region_func),
tvm::attr::kGlobalSymbol,
+ runtime::String(name));
+ global_region_func =
+ WithAttr(std::move(global_region_func), attr::kPrimitive,
tvm::Integer(1));
+ global_region_func = WithAttr(std::move(global_region_func),
attr::kCompiler,
+ tvm::runtime::String(target));
+ global_region_func =
+ WithAttr(std::move(global_region_func), attr::kInline,
tvm::Integer(1));
+
+ // Constant propagation
+ if (!params_bind.empty()) {
+ global_region_func = backend::BindParamsByName(global_region_func,
params_bind);
+ }
+
+ std::string fname = name;
+ CHECK(!module_->ContainGlobalVar(fname))
+ << "Global function " << fname << " already exists";
+ // Create a global function and add it to the IRModule for the region.
+ // This way we lift the functions that should be handled by external
+ // codegen to the module scope and rely on the pass manager to prevent
+ // relay function level passes (i.e. simplify inference and fusion)
+ // optimizing it.
+ GlobalVar glob_func(fname);
+ module_->Add(glob_func, global_region_func);
+
+ // The return type of callnode is the same as the type of the
+ // compiler_end node.
+ auto ret = Call(glob_func, param_expr);
+ region_function_calls[region] = ret;
+ }
+
+ /*!
+ * \brief Get the return(output) of the function for compiler end node
"end_arg".
+ * This will return either a Call (for a function with a single output) or a
+ * TupleGetItem (for a function with multiple outputs).
+ */
+ Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
+ Expr arg = Downcast<Call>(end_arg)->args[0];
+ // Function has one output.
+ if (region_return_indices_[region].size() == 1) {
+ return region_function_calls[region];
+ }
+ // Function has multiple outputs.
+ // Use already made TupleGetItem.
+ if (region_return_tuplegetitem_.count(region) &&
+ region_return_tuplegetitem_[region].count(arg)) {
+ return region_return_tuplegetitem_[region][arg];
+ }
+ // Create new TupleGetItem.
+ CHECK(region_return_indices_.count(region) &&
+ region_return_indices_[region].count(arg));
+ int index = region_return_indices_[region][arg];
+
+ auto func_call = region_function_calls[region];
+ auto tuple_get_item_ = TupleGetItem(func_call, index);
+ tuple_get_item_->checked_type_ = arg->checked_type_;
+ region_return_tuplegetitem_[region][arg] = tuple_get_item_;
+ return std::move(tuple_get_item_);
}
/*!
@@ -486,6 +493,23 @@ class Partitioner : public ExprMutator {
region_args;
/*!
+ * \brief This map maintains the index of an output in the subgraph function
+ * for a given region. If there are multiple entries for a region, then the
+ * function has a tuple of multiple outputs for its return.
+ */
+ using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash,
ObjectEqual>;
+ std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash,
ObjectEqual>
+ region_return_indices_;
+
+ /*!
+ * \brief This map holds already created TupleGetItem nodes for accessing
+ * outputs of a function.
+ */
+ using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem,
ObjectHash, ObjectEqual>;
+ std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash,
ObjectEqual>
+ region_return_tuplegetitem_;
+
+ /*!
* \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it
* belongs to
diff --git a/tests/python/relay/test_pass_partition_graph.py
b/tests/python/relay/test_pass_partition_graph.py
index 2ee8538..8827fbf 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -23,6 +23,7 @@ import pytest
import tvm
import tvm.relay.testing
+import tvm.relay.op as reg
from tvm import relay
from tvm import runtime
from tvm.relay import transform
@@ -1036,6 +1037,138 @@ def test_multiple_use_of_an_output():
test_same_output_region()
test_different_output_region()
+def test_duplicate_outputs():
+ target = "test_duplicate_outputs"
+
+ @reg.register("abs", "target." + target)
+ def abs(attrs, args): # pylint: disable=unused-variable
+ return True
+
+ def create_graph():
+ data = relay.var('data', shape=(10, 10))
+ x = relay.abs(data)
+ out_1 = relay.nn.relu(x)
+ out_2 = relay.tanh(x)
+ out_3 = relay.log(x)
+ out = relay.Tuple([out_1, out_2, out_3])
+ func = relay.Function([data], out)
+ return func
+
+ def expected():
+ mod = tvm.IRModule()
+
+ # function 0
+ f0_i0 = relay.var(target+"_0_i0", shape=(10, 10))
+ f0_o0 = relay.abs(f0_i0)
+ func0 = relay.Function([f0_i0], f0_o0)
+
+ func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+ func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+ func0 = func0.with_attr("Compiler", target)
+ func0 = func0.with_attr("global_symbol", target+"_0")
+ gv0 = relay.GlobalVar(target+"_0")
+ mod[gv0] = func0
+
+ # body
+ data = relay.var('data', shape=(10, 10))
+ function_out = gv0(data)
+ out_1 = relay.nn.relu(function_out)
+ out_2 = relay.tanh(function_out)
+ out_3 = relay.log(function_out)
+ out = relay.Tuple([out_1, out_2, out_3])
+ func = relay.Function([data], out)
+ mod["main"] = func
+ return mod
+
+ mod = tvm.IRModule()
+ mod["main"] = create_graph()
+
+ seq = tvm.transform.Sequential([
+ transform.AnnotateTarget(target),
+ transform.MergeCompilerRegions(),
+ transform.PartitionGraph(),
+ ])
+
+ ref_mod = expected()
+ partitioned = seq(mod)
+ assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
+
+def test_duplicate_merge_and_tuplegetitem():
+ target = "test_duplicate_merge_and_tuplegetitem"
+
+ @reg.register("nn.batch_norm", "target." + target)
+ def abs(attrs, args): # pylint: disable=unused-variable
+ return True
+
+ @reg.register("nn.relu", "target." + target)
+ def abs(attrs, args): # pylint: disable=unused-variable
+ return True
+
+ def create_graph():
+ data = relay.var('data', shape=(10, 10))
+ bn_gamma = relay.var("bn_gamma")
+ bn_beta = relay.var("bn_beta")
+ bn_mmean = relay.var("bn_mean")
+ bn_mvar = relay.var("bn_var")
+ x = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+ out_1 = relay.nn.relu(x[0])
+ bn_out_1 = x[1]
+ out_2 = relay.tanh(bn_out_1)
+ out_3 = relay.log(bn_out_1)
+ out = relay.Tuple([out_1, out_2, out_3])
+ func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar],
out)
+ return func
+
+ def expected():
+ mod = tvm.IRModule()
+
+ # function 0
+ f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
+ f0_i1 = relay.var(target+"_1_i1")
+ f0_i2 = relay.var(target+"_1_i2")
+ f0_i3 = relay.var(target+"_1_i3")
+ f0_i4 = relay.var(target+"_1_i4")
+ f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
+ f0_n1 = f0_n0[1]
+ f0_n2 = relay.nn.relu(f0_n0[0])
+ f0_o0 = relay.Tuple([f0_n1, f0_n2])
+ func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
+
+ func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+ func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+ func0 = func0.with_attr("Compiler", target)
+ func0 = func0.with_attr("global_symbol", target+"_1")
+ gv0 = relay.GlobalVar(target+"_1")
+ mod[gv0] = func0
+
+ # body
+ data = relay.var('data', shape=(10, 10))
+ bn_gamma = relay.var("bn_gamma")
+ bn_beta = relay.var("bn_beta")
+ bn_mmean = relay.var("bn_mean")
+ bn_mvar = relay.var("bn_var")
+ function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+ get_out0 = relay.TupleGetItem(function_out, 0)
+ get_out1 = relay.TupleGetItem(function_out, 1)
+ out_2 = relay.tanh(get_out0)
+ out_3 = relay.log(get_out0)
+ out = relay.Tuple([get_out1, out_2, out_3])
+ func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar],
out)
+ mod["main"] = func
+ return mod
+
+ mod = tvm.IRModule()
+ mod["main"] = create_graph()
+
+ seq = tvm.transform.Sequential([
+ transform.AnnotateTarget(target),
+ transform.MergeCompilerRegions(),
+ transform.PartitionGraph(),
+ ])
+
+ ref_mod = expected()
+ partitioned = seq(mod)
+ assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
if __name__ == "__main__":
test_multi_node_compiler()
@@ -1051,3 +1184,5 @@ if __name__ == "__main__":
test_mixed_single_multiple_outputs()
test_dnnl_fuse()
test_multiple_use_of_an_output()
+ test_duplicate_outputs()
+ test_duplicate_merge_and_tuplegetitem()