This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 66863f53f6 [Unity] Fix MergeCompositeFunctions for non-CallNode dataflow inputs (#14959)
66863f53f6 is described below
commit 66863f53f6687c2bc1240b731e6efc79349ebb57
Author: masahi <[email protected]>
AuthorDate: Fri May 26 16:33:13 2023 +0900
[Unity] Fix MergeCompositeFunctions for non-CallNode dataflow inputs (#14959)
* fix
* add test
* lint
---
src/relax/transform/fuse_ops.cc | 3 +-
src/relax/transform/merge_composite_functions.cc | 11 +++
.../test_transform_merge_composite_functions.py | 110 +++++++++++++++++++++
3 files changed, 122 insertions(+), 2 deletions(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 8940768ced..1942fbddfe 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -402,9 +402,8 @@ class FunctionCreator : public ExprMutator {
CheckDefAndUpdateParam(arg);
}
}
- } else {
+ } else if (var_binding->value.as<TupleGetItemNode>()) {
const auto* tuple_item = var_binding->value.as<TupleGetItemNode>();
- ICHECK(tuple_item != nullptr);
CheckDefAndUpdateParam(tuple_item->tuple);
}
diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc
index 81ee2ac7a1..0bc92ba923 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -84,6 +84,17 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator<Group*> {
for (const auto& param : func->params) {
memo_[param] = arena_->make<Group>();
}
+
+ PostOrderVisit(func, [this](Expr e) {
+ // Make default groups for dataflow nodes other than CallNode.
+ // Groups for CallNode are created in its visitor.
+ if (e->IsInstance<ConstantNode>() || e->IsInstance<ShapeExprNode>() ||
+ e->IsInstance<TupleNode>() || e->IsInstance<TupleGetItemNode>() ||
+ e->IsInstance<PrimValueNode>()) {
+ memo_[e] = arena_->make<Group>();
+ }
+ });
+
VisitExpr(func->body);
GroupMap group_map;
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py
index d5a3b1aa59..61df388c78 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -19,6 +19,7 @@ import pytest
import tvm
from tvm import relax
from tvm.script import relax as R
+from tvm.script import ir as I
@tvm.script.ir_module
@@ -1066,5 +1067,114 @@ def test_mixed_non_composite():
check(ModuleWithNonComposite, ModuleWithNonComposite_ref)
+def test_reshape():
+    # Verify that the non-CallNode input (shape in reshape) can be handled properly.
+ @I.ir_module
+ class Module:
+ @R.function
+ def fused_relax_matmul(
        lv: R.Tensor((1, 784), dtype="float32"), lv1: R.Tensor((784, 512), dtype="float32")
+ ) -> R.Tensor((1, 512), dtype="float32"):
+ R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 1})
+ with R.dataflow():
                gv: R.Tensor((1, 512), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32")
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_relax_reshape(
            inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0: R.Shape([1, 784])
+ ) -> R.Tensor((1, 784), dtype="float32"):
+ R.func_attr({"Composite": "tensorrt.reshape", "Primitive": 1})
+ with R.dataflow():
                gv: R.Tensor((1, 784), dtype="float32") = R.reshape(inp_0, param_0)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
+ linear_relu_stack_0_weight: R.Tensor((512, 784), dtype="float32"),
+ ) -> R.Tensor((1, 512), dtype="float32"):
+ cls = Module
+ with R.dataflow():
                lv: R.Tensor((1, 784), dtype="float32") = cls.fused_relax_reshape(
+ inp_0, R.shape([1, 784])
+ )
+ lv1: R.Tensor((784, 512), dtype="float32") = R.permute_dims(
+ linear_relu_stack_0_weight, axes=None
+ )
                lv_1: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_matmul(lv, lv1)
+ gv: R.Tensor((1, 512), dtype="float32") = lv_1
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def fused_relax_reshape_relax_matmul(
+ inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
+ param_0: R.Shape([1, 784]),
+ lv1: R.Tensor((784, 512), dtype="float32"),
+ ) -> R.Tensor((1, 512), dtype="float32"):
+ R.func_attr(
+ {
+ "Codegen": "tensorrt",
+ "Primitive": 1,
+ "global_symbol": "fused_relax_reshape_relax_matmul",
+ }
+ )
+ with R.dataflow():
+ # from tvm.script import relax as R
+
+ @R.function
+ def lv_1(
                inp_0_1: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0_1: R.Shape([1, 784])
+ ) -> R.Tensor((1, 784), dtype="float32"):
+ R.func_attr({"Composite": "tensorrt.reshape", "Primitive":
1})
+ with R.dataflow():
                        gv: R.Tensor((1, 784), dtype="float32") = R.reshape(inp_0_1, param_0_1)
+ R.output(gv)
+ return gv
+
                lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, param_0)
+
+ @R.function
+ def lv1_1_1(
+ lv_2: R.Tensor((1, 784), dtype="float32"),
+ lv1_2: R.Tensor((784, 512), dtype="float32"),
+ ) -> R.Tensor((1, 512), dtype="float32"):
+ R.func_attr({"Composite": "tensorrt.matmul", "Primitive":
1})
+ with R.dataflow():
+ gv: R.Tensor((1, 512), dtype="float32") = R.matmul(
+ lv_2, lv1_2, out_dtype="float32"
+ )
+ R.output(gv)
+ return gv
+
+ lv_2: R.Tensor((1, 512), dtype="float32") = lv1_1_1(lv_1, lv1)
+ gv: R.Tensor((1, 512), dtype="float32") = lv_2
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
+ linear_relu_stack_0_weight: R.Tensor((512, 784), dtype="float32"),
+ ) -> R.Tensor((1, 512), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ lv1: R.Tensor((784, 512), dtype="float32") = R.permute_dims(
+ linear_relu_stack_0_weight, axes=None
+ )
                gv: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_reshape_relax_matmul(
+ inp_0, R.shape([1, 784]), lv1
+ )
+ R.output(gv)
+ return gv
+
+ check(Module, Expected)
+
+
if __name__ == "__main__":
pytest.main([__file__])