This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 66863f53f6 [Unity] Fix MergeCompositeFunctions for non-CallNode dataflow inputs (#14959)
66863f53f6 is described below

commit 66863f53f6687c2bc1240b731e6efc79349ebb57
Author: masahi <[email protected]>
AuthorDate: Fri May 26 16:33:13 2023 +0900

    [Unity] Fix MergeCompositeFunctions for non-CallNode dataflow inputs (#14959)
    
    * fix
    
    * add test
    
    * lint
---
 src/relax/transform/fuse_ops.cc                    |   3 +-
 src/relax/transform/merge_composite_functions.cc   |  11 +++
 .../test_transform_merge_composite_functions.py    | 110 +++++++++++++++++++++
 3 files changed, 122 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 8940768ced..1942fbddfe 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -402,9 +402,8 @@ class FunctionCreator : public ExprMutator {
             CheckDefAndUpdateParam(arg);
           }
         }
-      } else {
+      } else if (var_binding->value.as<TupleGetItemNode>()) {
         const auto* tuple_item = var_binding->value.as<TupleGetItemNode>();
-        ICHECK(tuple_item != nullptr);
         CheckDefAndUpdateParam(tuple_item->tuple);
       }
 
diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc
index 81ee2ac7a1..0bc92ba923 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -84,6 +84,17 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator<Group*> {
     for (const auto& param : func->params) {
       memo_[param] = arena_->make<Group>();
     }
+
+    PostOrderVisit(func, [this](Expr e) {
+      // Make default groups for dataflow nodes other than CallNode.
+      // Groups for CallNode are created in its visitor.
+      if (e->IsInstance<ConstantNode>() || e->IsInstance<ShapeExprNode>() ||
+          e->IsInstance<TupleNode>() || e->IsInstance<TupleGetItemNode>() ||
+          e->IsInstance<PrimValueNode>()) {
+        memo_[e] = arena_->make<Group>();
+      }
+    });
+
     VisitExpr(func->body);
 
     GroupMap group_map;
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py
index d5a3b1aa59..61df388c78 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -19,6 +19,7 @@ import pytest
 import tvm
 from tvm import relax
 from tvm.script import relax as R
+from tvm.script import ir as I
 
 
 @tvm.script.ir_module
@@ -1066,5 +1067,114 @@ def test_mixed_non_composite():
     check(ModuleWithNonComposite, ModuleWithNonComposite_ref)
 
 
+def test_reshape():
+    # Verify that the non-CallNode input (shape in reshape) can be handled properly.
+    @I.ir_module
+    class Module:
+        @R.function
+        def fused_relax_matmul(
+            lv: R.Tensor((1, 784), dtype="float32"), lv1: R.Tensor((784, 512), dtype="float32")
+        ) -> R.Tensor((1, 512), dtype="float32"):
+            R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 512), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32")
+                R.output(gv)
+            return gv
+
+        @R.function
+        def fused_relax_reshape(
+            inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0: R.Shape([1, 784])
+        ) -> R.Tensor((1, 784), dtype="float32"):
+            R.func_attr({"Composite": "tensorrt.reshape", "Primitive": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 784), dtype="float32") = R.reshape(inp_0, param_0)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
+            linear_relu_stack_0_weight: R.Tensor((512, 784), dtype="float32"),
+        ) -> R.Tensor((1, 512), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                lv: R.Tensor((1, 784), dtype="float32") = cls.fused_relax_reshape(
+                    inp_0, R.shape([1, 784])
+                )
+                lv1: R.Tensor((784, 512), dtype="float32") = R.permute_dims(
+                    linear_relu_stack_0_weight, axes=None
+                )
+                lv_1: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_matmul(lv, lv1)
+                gv: R.Tensor((1, 512), dtype="float32") = lv_1
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def fused_relax_reshape_relax_matmul(
+            inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
+            param_0: R.Shape([1, 784]),
+            lv1: R.Tensor((784, 512), dtype="float32"),
+        ) -> R.Tensor((1, 512), dtype="float32"):
+            R.func_attr(
+                {
+                    "Codegen": "tensorrt",
+                    "Primitive": 1,
+                    "global_symbol": "fused_relax_reshape_relax_matmul",
+                }
+            )
+            with R.dataflow():
+                # from tvm.script import relax as R
+
+                @R.function
+                def lv_1(
+                    inp_0_1: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0_1: R.Shape([1, 784])
+                ) -> R.Tensor((1, 784), dtype="float32"):
+                    R.func_attr({"Composite": "tensorrt.reshape", "Primitive": 1})
+                    with R.dataflow():
+                        gv: R.Tensor((1, 784), dtype="float32") = R.reshape(inp_0_1, param_0_1)
+                        R.output(gv)
+                    return gv
+
+                lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, param_0)
+
+                @R.function
+                def lv1_1_1(
+                    lv_2: R.Tensor((1, 784), dtype="float32"),
+                    lv1_2: R.Tensor((784, 512), dtype="float32"),
+                ) -> R.Tensor((1, 512), dtype="float32"):
+                    R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 1})
+                    with R.dataflow():
+                        gv: R.Tensor((1, 512), dtype="float32") = R.matmul(
+                            lv_2, lv1_2, out_dtype="float32"
+                        )
+                        R.output(gv)
+                    return gv
+
+                lv_2: R.Tensor((1, 512), dtype="float32") = lv1_1_1(lv_1, lv1)
+                gv: R.Tensor((1, 512), dtype="float32") = lv_2
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
+            linear_relu_stack_0_weight: R.Tensor((512, 784), dtype="float32"),
+        ) -> R.Tensor((1, 512), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                lv1: R.Tensor((784, 512), dtype="float32") = R.permute_dims(
+                    linear_relu_stack_0_weight, axes=None
+                )
+                gv: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_reshape_relax_matmul(
+                    inp_0, R.shape([1, 784]), lv1
+                )
+                R.output(gv)
+            return gv
+
+    check(Module, Expected)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to