This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 046b0d9  [BYOC] Bind constant tuples in graph partitioner (#5476)
046b0d9 is described below

commit 046b0d98a08153a4829a12cc81a4fa856be6efcd
Author: mbaret <55580676+mba...@users.noreply.github.com>
AuthorDate: Wed Apr 29 19:23:15 2020 +0100

    [BYOC] Bind constant tuples in graph partitioner (#5476)
    
    * Bind constant tuples in the graph partitioner
    
    Change-Id: I815b32b5445a536c1837369b04f67dbbb0aed900
    
    * Add partitioning test
    
    Change-Id: I3a492ec8d1beab4830214e3bc8da2a7c80771ca4
    
    * Rename test target
    
    Change-Id: Ie32f37c1395ff597c0047ad3a93ed04ce3f3125d
---
 src/relay/transforms/partition_graph.cc         | 22 ++++++++++++---
 tests/python/relay/test_pass_partition_graph.py | 37 +++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 15ad60b..3b0d6bc 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -393,12 +393,26 @@ class Partitioner : public ExprMutator {
 
     Array<Var> params;
     Array<Expr> param_expr;
-    std::unordered_map<std::string, runtime::NDArray> params_bind;
+    Map<Var, Expr> params_bind;
+
+    auto IsConstant = [](const Expr& expr) {
+      if (expr->IsInstance<ConstantNode>())
+        return true;
+      if (expr->IsInstance<TupleNode>()) {
+        auto tuple = expr.as<TupleNode>();
+        for (const auto& field : tuple->fields) {
+          if (!field->IsInstance<ConstantNode>())
+            return false;
+        }
+        return true;
+      }
+      return false;
+    };
 
     for (auto pair : region_args[region]) {
       params.push_back(pair.first);
-      if (const auto* cn = pair.second.as<ConstantNode>()) {
-        params_bind[pair.first->name_hint()] = cn->data;
+      if (IsConstant(pair.second)) {
+        params_bind.Set(pair.first, pair.second);
       } else {
         param_expr.push_back(pair.second);
       }
@@ -428,7 +442,7 @@ class Partitioner : public ExprMutator {
 
     // Constant propagation
     if (!params_bind.empty()) {
-      global_region_func = backend::BindParamsByName(global_region_func, params_bind);
+      global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
     }
 
     std::string fname = name;
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 2a4fd31..d78b9ea 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -1155,6 +1155,42 @@ def test_duplicate_merge_and_tuplegetitem():
     partitioned = seq(mod)
     assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
+def test_constant_tuples():
+    @reg.register("qnn.concatenate", "target.const_tuples")
+    def add(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    def create_graph():
+        a = relay.var('a', shape=(10, 10), dtype="uint8")
+        b = relay.var('b', shape=(10, 10), dtype="uint8")
+        a1 = relay.abs(a)
+
+        zeroi = relay.const(1, "int32")
+        zerof = relay.const(0, "float32")
+        con = relay.qnn.op.concatenate((a1, b),
+                                       input_scales=(zerof, zerof),
+                                       input_zero_points=(zeroi, zeroi),
+                                       output_scale=zerof,
+                                       output_zero_point=zeroi,
+                                       axis=1)
+
+        f = relay.Function([a, b], con)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    seq = tvm.transform.Sequential([
+        transform.AnnotateTarget("const_tuples"),
+        transform.MergeCompilerRegions(),
+        transform.PartitionGraph(),
+    ])
+
+    partitioned = seq(create_graph())
+    concat = partitioned["const_tuples_0"].body
+    assert type(concat.args[1]) == relay.Tuple
+    assert type(concat.args[2]) == relay.Tuple
+    assert type(concat.args[3]) == relay.Constant
+    assert type(concat.args[4]) == relay.Constant
+
 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -1171,3 +1207,4 @@ if __name__ == "__main__":
     test_multiple_use_of_an_output()
     test_duplicate_outputs()
     test_duplicate_merge_and_tuplegetitem()
+    test_constant_tuples()

Reply via email to