This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8ea976276c [Unity] Properly handle tuple-outputting function in `FuseOpsByPattern` (#14525)
8ea976276c is described below

commit 8ea976276cde885359280ec1a7d7d78d4b2c06bb
Author: masahi <[email protected]>
AuthorDate: Fri Apr 7 21:25:18 2023 +0900

    [Unity] Properly handle tuple-outputting function in `FuseOpsByPattern` (#14525)
    
    * correctly handle tuple output function
    
    * properly handle nested tuple case
    
    * typo
    
    * fix MergeCompositeFunctions test
---
 src/relax/transform/fuse_ops.cc                    | 28 ++++++--
 .../relax/test_transform_fuse_ops_by_pattern.py    | 79 +++++++++++++++++++++-
 2 files changed, 102 insertions(+), 5 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 8e4346e206..cf9bd0ac37 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -659,6 +659,16 @@ class OperatorFusor : public ExprMutator {
     return sinfo->ret->IsInstance<TupleStructInfoNode>();
   }
 
+  bool IsNestedTupleOutput(Function f) {
+    if (!IsTupleOutput(f)) return false;
+
+    auto tup = GetStructInfo(f).as<FuncStructInfoNode>()->ret.as<TupleStructInfoNode>();
+    for (const auto& field : tup->fields) {
+      if (field->IsInstance<TupleStructInfoNode>()) return true;
+    }
+    return false;
+  }
+
   BindingBlock VisitBindingBlock(const BindingBlock& block) final {
     if (const auto* df_block = block.as<DataflowBlockNode>()) {
       return VisitBindingBlock_(df_block);
@@ -722,7 +732,12 @@ class OperatorFusor : public ExprMutator {
       // needs to be remapped to the output of TupleGetItem after the corresponding tuple is
       // emitted.
       if (IsTupleOutput(func) && tuple_get_indices_.count(binding->var.get())) {
-        pending_tuple_get[group].push_back(binding->var);
+        if (!GetStructInfo(binding->var)->IsInstance<TupleStructInfoNode>() ||
+            IsNestedTupleOutput(func)) {
+          // When binding->var itself is a tuple, we do not need to remap this variable to the
+          // output of TupleGetItem unless the output is a nested tuple.
+          pending_tuple_get[group].push_back(binding->var);
+        }
       }
 
       // Case 2. If the binding is not the last binding of the group, we skip it.
@@ -751,7 +766,7 @@ class OperatorFusor : public ExprMutator {
       }
 
       // Step c. Update the mapping used for the remapping of the binding variables.
-      if (IsTupleOutput(func)) {
+      if (IsTupleOutput(func) && !pending_tuple_get.empty()) {
         // If the output is a tuple, attach TupleGetItem to all tuple elements, and
         // remap variables approriately.
         // The variables that need to be remapped and the corresponding tuple indices are
@@ -1018,8 +1033,13 @@ class PatternBasedPartitioner : ExprVisitor {
       ICHECK(parent_group);
       parent_group->attrs.Set(attr::kComposite, pat_name_);
       for (const auto& [pat, match] : matches_opt.value()) {
-        // Put all matching call nodes into the parent group.
-        if (pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call)) {
+        // Put all matching expressions into the parent group. But we need to be careful not to
+        // merge expressions matched by a wildcard pattern, since a wildcard can match an output of
+        // the previous group. For example, when there are two back-to-back conv2d ops, the output
+        // of the first conv2d is matched to the input of the second conv2d via a wildcard pattern.
+        // But we must avoid merging the first conv2d into the group of the second conv2d.
+        if ((pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call)) ||
+            pat->IsInstance<TupleGetItemPatternNode>()) {
           // Put the bound variable on the LHS into the same parent group.
           AddToGroup(value_to_bound_var_[match], parent_group);
         }
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 2f3e2d479f..146c8e1ebc 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -19,7 +19,12 @@ import pytest
 
 import tvm
 from tvm import relax
-from tvm.relax.dpl.pattern import is_op, make_fused_bias_activation_pattern, wildcard
+from tvm.relax.dpl.pattern import (
+    is_op,
+    make_fused_bias_activation_pattern,
+    wildcard,
+    is_tuple_get_item,
+)
 from tvm.relax.transform import PatternCheckContext
 from tvm.script import ir as I
 from tvm.script import relax as R
@@ -671,5 +676,77 @@ def test_bind_constants():
     )
 
 
+def test_split():
+    @R.function
+    def func(inp: R.Tensor((16, 32), "float32")):
+        with R.dataflow():
+            tup = R.split(inp, [16], axis=1)
+            out = R.add(tup[0], tup[1])
+            R.output(out)
+        return out
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def fused_relax_split(
+            inp: R.Tensor((16, 32), dtype="float32")
+        ) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")):
+            R.func_attr({"Composite": "x.split", "Primitive": 1})
+            with R.dataflow():
+                gv: R.Tuple(
+                    R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")
+                ) = R.split(inp, indices_or_sections=[16], axis=1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
+            cls = Expected1
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")
+                ) = cls.fused_relax_split(inp)
+                lv1: R.Tensor((16, 16), dtype="float32") = lv[0]
+                lv2: R.Tensor((16, 16), dtype="float32") = lv[1]
+                out: R.Tensor((16, 16), dtype="float32") = R.add(lv1, lv2)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def fused_relax_split_relax_add(
+            inp: R.Tensor((16, 32), dtype="float32")
+        ) -> R.Tensor((16, 16), dtype="float32"):
+            R.func_attr({"Composite": "x.split", "Primitive": 1})
+            with R.dataflow():
+                tup: R.Tuple(
+                    R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")
+                ) = R.split(inp, indices_or_sections=[16], axis=1)
+                lv1: R.Tensor((16, 16), dtype="float32") = tup[0]
+                lv2: R.Tensor((16, 16), dtype="float32") = tup[1]
+                gv: R.Tensor((16, 16), dtype="float32") = R.add(lv1, lv2)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
+            cls = Expected2
+            with R.dataflow():
+                gv: R.Tensor((16, 16), dtype="float32") = cls.fused_relax_split_relax_add(inp)
+                R.output(gv)
+            return gv
+
+    mod = tvm.IRModule({"main": func})
+
+    split = is_op("relax.split")(wildcard())
+    it1 = is_tuple_get_item(split, 0)
+    it2 = is_tuple_get_item(split, 1)
+    add = is_op("relax.add")(it1, it2)
+
+    check(mod, [("x.split", split)], Expected1)
+    check(mod, [("x.split", add)], Expected2)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to