This is an automated email from the ASF dual-hosted git repository. zhic pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push: new 046b0d9 [BYOC] Bind constant tuples in graph partitioner (#5476) 046b0d9 is described below commit 046b0d98a08153a4829a12cc81a4fa856be6efcd Author: mbaret <55580676+mba...@users.noreply.github.com> AuthorDate: Wed Apr 29 19:23:15 2020 +0100 [BYOC] Bind constant tuples in graph partitioner (#5476) * Bind constant tuples in the graph partitioner Change-Id: I815b32b5445a536c1837369b04f67dbbb0aed900 * Add partitioning test Change-Id: I3a492ec8d1beab4830214e3bc8da2a7c80771ca4 * Rename test target Change-Id: Ie32f37c1395ff597c0047ad3a93ed04ce3f3125d --- src/relay/transforms/partition_graph.cc | 22 ++++++++++++--- tests/python/relay/test_pass_partition_graph.py | 37 +++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 15ad60b..3b0d6bc 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -393,12 +393,26 @@ class Partitioner : public ExprMutator { Array<Var> params; Array<Expr> param_expr; - std::unordered_map<std::string, runtime::NDArray> params_bind; + Map<Var, Expr> params_bind; + + auto IsConstant = [](const Expr& expr) { + if (expr->IsInstance<ConstantNode>()) + return true; + if (expr->IsInstance<TupleNode>()) { + auto tuple = expr.as<TupleNode>(); + for (const auto& field : tuple->fields) { + if (!field->IsInstance<ConstantNode>()) + return false; + } + return true; + } + return false; + }; for (auto pair : region_args[region]) { params.push_back(pair.first); - if (const auto* cn = pair.second.as<ConstantNode>()) { - params_bind[pair.first->name_hint()] = cn->data; + if (IsConstant(pair.second)) { + params_bind.Set(pair.first, pair.second); } else { param_expr.push_back(pair.second); } @@ -428,7 +442,7 @@ class Partitioner : public ExprMutator { // Constant propagation if (!params_bind.empty()) { - global_region_func = backend::BindParamsByName(global_region_func, params_bind); + global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind)); } std::string fname = name; diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 2a4fd31..d78b9ea 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -1155,6 +1155,42 @@ def test_duplicate_merge_and_tuplegetitem(): partitioned = seq(mod) assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) +def test_constant_tuples(): + @reg.register("qnn.concatenate", "target.const_tuples") + def add(attrs, args): # pylint: disable=unused-variable + return True + + def create_graph(): + a = relay.var('a', shape=(10, 10), dtype="uint8") + b = relay.var('b', shape=(10, 10), dtype="uint8") + a1 = relay.abs(a) + + zeroi = relay.const(1, "int32") + zerof = relay.const(0, "float32") + con = relay.qnn.op.concatenate((a1, b), + input_scales=(zerof, zerof), + input_zero_points=(zeroi, zeroi), + output_scale=zerof, + output_zero_point=zeroi, + axis=1) + + f = relay.Function([a, b], con) + mod = tvm.IRModule.from_expr(f) + return mod + + seq = tvm.transform.Sequential([ + transform.AnnotateTarget("const_tuples"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ]) + + partitioned = seq(create_graph()) + concat = partitioned["const_tuples_0"].body + assert type(concat.args[1]) == relay.Tuple + assert type(concat.args[2]) == relay.Tuple + assert type(concat.args[3]) == relay.Constant + assert type(concat.args[4]) == relay.Constant + if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op() @@ -1171,3 +1207,4 @@ if __name__ == "__main__": test_multiple_use_of_an_output() test_duplicate_outputs() test_duplicate_merge_and_tuplegetitem() + test_constant_tuples()