This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 09eb508  [BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)
09eb508 is described below

commit 09eb50820b1629f77531acaad16701bf572293f0
Author: Trevor Morris <[email protected]>
AuthorDate: Wed Apr 15 13:33:31 2020 -0700

    [BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)
    
    * Fix duplicate output in partitiongraph
    
    * Add test case
    
    * Fix test_annotated_regions with duplicate compiler_end outputs
    
    * Revert "Fix duplicate output in partitiongraph"
    
    This reverts commit e1f8ef3f4ca5b2aaa31ace6fa968bb50e5e4d1fa.
    
    * Prevent duplicate outputs in Tuple in PartitionGraph
    
    * Fix lint
    
    * Add another test case for when regions are merged, and when TupleGetItem was duplicated
    
    * Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput
    
    * Use std::move for GetFunctionOutput. Fix typo with testcase name
    
    * Use tvm.transform.Sequential
---
 src/relay/transforms/partition_graph.cc         | 226 +++++++++++++-----------
 tests/python/relay/test_pass_partition_graph.py | 135 ++++++++++++++
 2 files changed, 260 insertions(+), 101 deletions(-)
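
For context: the bug fixed here shows up when several compiler_end annotations
point at the same value. PartitionGraph would then emit that value once per
annotation in the partitioned function's return Tuple. The following minimal
sketch, condensed from the new test_duplicate_outputs case below (the target
name "my_codegen" is arbitrary), reproduces the scenario:

    import tvm
    import tvm.relay.op as reg
    from tvm import relay
    from tvm.relay import transform

    target = "my_codegen"  # hypothetical external codegen name

    @reg.register("abs", "target." + target)
    def _abs(attrs, args):  # only abs is offloaded to the external target
        return True

    data = relay.var("data", shape=(10, 10))
    x = relay.abs(data)  # one offloaded value ...
    out = relay.Tuple([relay.nn.relu(x), relay.tanh(x), relay.log(x)])  # ... three consumers
    mod = tvm.IRModule.from_expr(relay.Function([data], out))

    seq = tvm.transform.Sequential([
        transform.AnnotateTarget(target),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ])
    print(seq(mod))

Before this change, the extracted external function returned all three copies
of abs(data); with it, CreateFunction deduplicates the outputs and
GetFunctionOutput hands each consumer the same Call (or TupleGetItem).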

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index c8367fb..15ad60b 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -205,99 +205,13 @@ class Partitioner : public ExprMutator {
       // region_function_calls is map that maintains
       // (each annotated regions) --> created function
 
-      if (region_function_calls.find(region) != region_function_calls.end()) {
-        // This section is executed only if there are multiple outputs in the
-        // region Thus, the function is always created and at the end there
-        // would be a tuple node Therefore, we insert a tuple get item node.
-
-        // Use the already created tuple node
-        auto sg_call = region_function_calls[region];
-        int index = GetRetIdx(region, GetRef<Call>(call));
-        CHECK_NE(index, -1);
-
-        auto tuple_get_item_ = TupleGetItem(sg_call, index);
-        tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-        return std::move(tuple_get_item_);
-      } else {
-        // First time this region is encountered in the traversal
-        // Creating the function
-
-        Array<Expr> fields;
-
-        for (auto ret : region->GetOutputs()) {
-          auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
-          fields.push_back(ret_expr);
-        }
-        int index = GetRetIdx(region, GetRef<Call>(call));
-        CHECK_NE(index, -1);
-
-        Array<Var> params;
-        Array<Expr> param_expr;
-        std::unordered_map<std::string, runtime::NDArray> params_bind;
-
-        for (auto pair : region_args[region]) {
-          params.push_back(pair.first);
-          if (const auto* cn = pair.second.as<ConstantNode>()) {
-            params_bind[pair.first->name_hint()] = cn->data;
-          } else {
-            param_expr.push_back(pair.second);
-          }
-        }
-
-        Function global_region_func;
-        if (region->GetOutputs().size() == 1) {
-          // If there are only a single output; no need to add a tuple
-          global_region_func =
-              Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
-        } else {
-          auto tuple = Tuple(fields);
-          global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
-        }
-
-        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
-        std::string name = target + "_" + std::to_string(region->GetID());
-
-        global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
-                                      runtime::String(name));
-        global_region_func =
-            WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
-        global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
-                                      tvm::runtime::String(target));
-        global_region_func =
-            WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
-
-        // Constant propagation
-        if (!params_bind.empty()) {
-          global_region_func = backend::BindParamsByName(global_region_func, params_bind);
-        }
-
-        std::string fname = name;
-        CHECK(!module_->ContainGlobalVar(fname))
-            << "Global function " << fname << " already exists";
-        // Create a global function and add it to the IRModule for the region.
-        // This way we lift the functions that should be handled by external
-        // codegen to the module scope and rely on the pass manager to prevent
-        // relay function level passes (i.e. simplify inference and fusion)
-        // optimizing it.
-        GlobalVar glob_func(fname);
-        module_->Add(glob_func, global_region_func);
-
-        // The return type of callnode is the same as the type of the
-        // compiler_end node.
-        auto ret = Call(glob_func, param_expr);
-        region_function_calls[region] = ret;
-
-        if (region->GetOutputs().size() == 1) {
-          // If there is only a single output; no need to add a tuplegetitem
-          // node
-          return std::move(ret);
-        } else {
-          // Add a tuplegetitem node to select this output out of many
-          auto tuple_get_item_ = TupleGetItem(ret, index);
-          tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-          return std::move(tuple_get_item_);
-        }
+      if (region_function_calls.find(region) == region_function_calls.end()) {
+        // First time this region is encountered in the traversal.
+        // Creating the function.
+        CreateFunction(region, call);
       }
+      // Retrieve this particular output of the function.
+      return GetFunctionOutput(region, GetRef<Call>(call));
     }
   }
 
@@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief Get the index of the return(output);
-   * this is to be used as tuplegetitem idx
+   * \brief This function is called the first time we encounter a compiler_end
+   * node to create the function for the subgraph.
    */
-  int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
-    int idx = 0;
-    for (auto arg_ : sg->GetOutputs()) {
-      if (arg == arg_) {
-        return idx;
+  void CreateFunction(AnnotatedRegion region, const CallNode* call) {
+    // Create fields, a unique list of outputs. Also populate the
+    // region_return_indices_ map, which maps the parent of a compiler_end
+    // node to its corresponding index in fields.
+    Array<Expr> fields;
+    int i = 0;
+    for (auto ret : region->GetOutputs()) {
+      auto ret_node = Downcast<Call>(ret)->args[0];
+      // Don't duplicate outputs.
+      if (!region_return_indices_.count(region) ||
+          !region_return_indices_[region].count(ret_node)) {
+        auto ret_expr = VisitExpr(ret_node);
+        fields.push_back(ret_expr);
+        region_return_indices_[region][ret_node] = i;
+        i++;
       }
-      idx++;
     }
-    return -1;
+
+    Array<Var> params;
+    Array<Expr> param_expr;
+    std::unordered_map<std::string, runtime::NDArray> params_bind;
+
+    for (auto pair : region_args[region]) {
+      params.push_back(pair.first);
+      if (const auto* cn = pair.second.as<ConstantNode>()) {
+        params_bind[pair.first->name_hint()] = cn->data;
+      } else {
+        param_expr.push_back(pair.second);
+      }
+    }
+
+    Function global_region_func;
+    if (fields.size() == 1) {
+      // If there is only a single output, there is no need to add a tuple.
+      global_region_func =
+          Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
+    } else {
+      auto tuple = Tuple(fields);
+      global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
+    }
+
+    std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+    std::string name = target + "_" + std::to_string(region->GetID());
+
+    global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
+                                  runtime::String(name));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
+    global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
+                                  tvm::runtime::String(target));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
+
+    // Constant propagation
+    if (!params_bind.empty()) {
+      global_region_func = backend::BindParamsByName(global_region_func, params_bind);
+    }
+
+    std::string fname = name;
+    CHECK(!module_->ContainGlobalVar(fname))
+        << "Global function " << fname << " already exists";
+    // Create a global function and add it to the IRModule for the region.
+    // This way we lift the functions that should be handled by external
+    // codegen to the module scope and rely on the pass manager to prevent
+    // relay function level passes (i.e. simplify inference and fusion)
+    // optimizing it.
+    GlobalVar glob_func(fname);
+    module_->Add(glob_func, global_region_func);
+
+    // The return type of callnode is the same as the type of the
+    // compiler_end node.
+    auto ret = Call(glob_func, param_expr);
+    region_function_calls[region] = ret;
+  }
+
+  /*!
+   * \brief Get the return (output) of the function for the compiler_end node "end_arg".
+   * This will return either a Call (for a function with a single output) or a
+   * TupleGetItem (for a function with multiple outputs).
+   */
+  Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
+    Expr arg = Downcast<Call>(end_arg)->args[0];
+    // Function has one output.
+    if (region_return_indices_[region].size() == 1) {
+      return region_function_calls[region];
+    }
+    // Function has multiple outputs.
+    // Use already made TupleGetItem.
+    if (region_return_tuplegetitem_.count(region) &&
+        region_return_tuplegetitem_[region].count(arg)) {
+      return region_return_tuplegetitem_[region][arg];
+    }
+    // Create new TupleGetItem.
+    CHECK(region_return_indices_.count(region) &&
+          region_return_indices_[region].count(arg));
+    int index = region_return_indices_[region][arg];
+
+    auto func_call = region_function_calls[region];
+    auto tuple_get_item_ = TupleGetItem(func_call, index);
+    tuple_get_item_->checked_type_ = arg->checked_type_;
+    region_return_tuplegetitem_[region][arg] = tuple_get_item_;
+    return std::move(tuple_get_item_);
   }
 
   /*!
@@ -486,6 +493,23 @@ class Partitioner : public ExprMutator {
       region_args;
 
   /*!
+   * \brief This map maintains the index of an output in the subgraph function
+   * for a given region. If there are multiple entries for a region, the
+   * function returns a tuple of multiple outputs.
+   */
+  using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
+      region_return_indices_;
+
+  /*!
+   * \brief This map holds already created TupleGetItem nodes for accessing
+   * outputs of a function.
+   */
+  using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
+      region_return_tuplegetitem_;
+
+  /*!
    * \brief Each region set is associated with a function in the module.
    * This map maintains the mapping between regionsets and the function it
    * belongs to
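
To make the new control flow easier to follow, here is an illustrative Python
sketch of the bookkeeping shared by CreateFunction and GetFunctionOutput
(plain dicts stand in for the C++ unordered_maps keyed on region and expr;
make_tuple_get_item is a stand-in for constructing a TupleGetItem node):

    # region -> {output expr -> index into the function's return Tuple}
    region_return_indices = {}
    # region -> {output expr -> cached TupleGetItem for that output}
    region_return_tuplegetitem = {}

    def collect_unique_outputs(region, outputs):
        """CreateFunction's loop: give each distinct output one tuple slot."""
        fields = []
        indices = region_return_indices.setdefault(region, {})
        for ret in outputs:
            if ret not in indices:  # don't duplicate outputs
                indices[ret] = len(fields)
                fields.append(ret)
        return fields

    def get_function_output(region, func_call, ret, make_tuple_get_item):
        """GetFunctionOutput: the Call itself for a single output,
        otherwise one shared TupleGetItem per distinct output."""
        indices = region_return_indices[region]
        if len(indices) == 1:
            return func_call
        cache = region_return_tuplegetitem.setdefault(region, {})
        if ret not in cache:  # reuse nodes across duplicate consumers
            cache[ret] = make_tuple_get_item(func_call, indices[ret])
        return cache[ret]
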
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 2ee8538..8827fbf 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -23,6 +23,7 @@ import pytest
 
 import tvm
 import tvm.relay.testing
+import tvm.relay.op as reg
 from tvm import relay
 from tvm import runtime
 from tvm.relay import transform
@@ -1036,6 +1037,138 @@ def test_multiple_use_of_an_output():
     test_same_output_region()
     test_different_output_region()
 
+def test_duplicate_outputs():
+    target = "test_duplicate_outputs"
+
+    @reg.register("abs", "target." + target)
+    def abs(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    def create_graph():
+        data = relay.var('data', shape=(10, 10))
+        x = relay.abs(data)
+        out_1 = relay.nn.relu(x)
+        out_2 = relay.tanh(x)
+        out_3 = relay.log(x)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data], out)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        f0_i0 = relay.var(target+"_0_i0", shape=(10, 10))
+        f0_o0 = relay.abs(f0_i0)
+        func0 = relay.Function([f0_i0], f0_o0)
+
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler", target)
+        func0 = func0.with_attr("global_symbol", target+"_0")
+        gv0 = relay.GlobalVar(target+"_0")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var('data', shape=(10, 10))
+        function_out = gv0(data)
+        out_1 = relay.nn.relu(function_out)
+        out_2 = relay.tanh(function_out)
+        out_3 = relay.log(function_out)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data], out)
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+
+    seq = tvm.transform.Sequential([
+        transform.AnnotateTarget(target),
+        transform.MergeCompilerRegions(),
+        transform.PartitionGraph(),
+    ])
+
+    ref_mod = expected()
+    partitioned = seq(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
+
+def test_duplicate_merge_and_tuplegetitem():
+    target = "test_duplicate_merge_and_tuplegetitem"
+
+    @reg.register("nn.batch_norm", "target." + target)
+    def abs(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    @reg.register("nn.relu", "target." + target)
+    def abs(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    def create_graph():
+        data = relay.var('data', shape=(10, 10))
+        bn_gamma = relay.var("bn_gamma")
+        bn_beta = relay.var("bn_beta")
+        bn_mmean = relay.var("bn_mean")
+        bn_mvar = relay.var("bn_var")
+        x = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        out_1 = relay.nn.relu(x[0])
+        bn_out_1 = x[1]
+        out_2 = relay.tanh(bn_out_1)
+        out_3 = relay.log(bn_out_1)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
+        f0_i1 = relay.var(target+"_1_i1")
+        f0_i2 = relay.var(target+"_1_i2")
+        f0_i3 = relay.var(target+"_1_i3")
+        f0_i4 = relay.var(target+"_1_i4")
+        f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
+        f0_n1 = f0_n0[1]
+        f0_n2 = relay.nn.relu(f0_n0[0])
+        f0_o0 = relay.Tuple([f0_n1, f0_n2])
+        func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
+
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler", target)
+        func0 = func0.with_attr("global_symbol", target+"_1")
+        gv0 = relay.GlobalVar(target+"_1")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var('data', shape=(10, 10))
+        bn_gamma = relay.var("bn_gamma")
+        bn_beta = relay.var("bn_beta")
+        bn_mmean = relay.var("bn_mean")
+        bn_mvar = relay.var("bn_var")
+        function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        get_out0 = relay.TupleGetItem(function_out, 0)
+        get_out1 = relay.TupleGetItem(function_out, 1)
+        out_2 = relay.tanh(get_out0)
+        out_3 = relay.log(get_out0)
+        out = relay.Tuple([get_out1, out_2, out_3])
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+
+    seq = tvm.transform.Sequential([
+        transform.AnnotateTarget(target),
+        transform.MergeCompilerRegions(),
+        transform.PartitionGraph(),
+    ])
+
+    ref_mod = expected()
+    partitioned = seq(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
 if __name__ == "__main__":
     test_multi_node_compiler()
@@ -1051,3 +1184,5 @@ if __name__ == "__main__":
     test_mixed_single_multiple_outputs()
     test_dnnl_fuse()
     test_multiple_use_of_an_output()
+    test_duplicate_outputs()
+    test_duplicate_merge_and_tuplegetitem()
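
The two new cases can also be run in isolation, e.g. with
pytest tests/python/relay/test_pass_partition_graph.py -k duplicate.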
