This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new add45b5c1c [Unity] Make FuseOps work on a call_tir which has a
ShapeExpr arg (#14553)
add45b5c1c is described below
commit add45b5c1cc7167c8648aeed2852321d4e860a59
Author: Hongyi Jin <[email protected]>
AuthorDate: Mon Apr 10 02:34:56 2023 -0400
[Unity] Make FuseOps work on a call_tir which has a ShapeExpr arg (#14553)
The pattern of the ShapeExpr should be set to `kOpaque` in GraphCreator.
Previously it was not handled and would run into an error.
---
src/relax/transform/fuse_ops.cc | 2 +-
tests/python/relax/test_transform_fuse_ops.py | 54 +++++++++++++++++++++++++++
2 files changed, 55 insertions(+), 1 deletion(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index b01097aa1b..a49ae86267 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -252,7 +252,7 @@ class GraphCreator : public ExprVisitor {
IndexedForwardGraph::Node* leaf_node = nullptr;
if (it != graph_.node_map.end()) {
leaf_node = it->second;
- } else if (leaf_expr->IsInstance<ConstantNode>()) {
+ } else if (leaf_expr->IsInstance<ConstantNode>() ||
leaf_expr->IsInstance<ShapeExprNode>()) {
leaf_node = CreateNode(leaf_expr.get());
// Since we never fuse constants, the pattern of the constant is set to
`kOpaque`.
SetNodePattern(leaf_node, OpPatternKind::kOpaque);
diff --git a/tests/python/relax/test_transform_fuse_ops.py
b/tests/python/relax/test_transform_fuse_ops.py
index cf8efb0587..285a78a30e 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1333,5 +1333,59 @@ def test_symbolic_shape_aware_fuse_2():
_check(Before, Expected)
+def test_shape_expr_arg():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(s: R.Shape(["n"]), kv_cache: R.Object):
+ n = T.int64()
+ with R.dataflow():
+ lv0 = R.emit_te(topi.full, [n, n], "float32", 0)
+ lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"),
upper=True)
+ lv2 = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
+ gv = R.call_packed(
+ "vm.builtin.attention_kv_cache_view",
+ kv_cache,
+ R.shape([1 + n, 32, 128]),
+ sinfo_args=(R.Tensor((1 + n, 32, 128), dtype="float32"),),
+ )
+ R.output(gv, lv2)
+ return gv, lv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def fused_full_trilu_broadcast_to(
+ s: R.Shape(["n"]),
+ ) -> R.Tensor([1, 1, "n", "n"], "float32"):
+ R.func_attr({"Primitive": 1})
+ n = T.int64()
+ with R.dataflow():
+ lv0 = R.emit_te(topi.full, [n, n], "float32", 0)
+ lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"),
upper=True)
+ gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(s: R.Shape(["n"]), kv_cache: R.Object):
+ cls = Expected
+ n = T.int64()
+ with R.dataflow():
+ lv: R.Tensor([1, 1, n, n], "float32") =
cls.fused_full_trilu_broadcast_to(
+ R.shape([n])
+ )
+ gv = R.call_packed(
+ "vm.builtin.attention_kv_cache_view",
+ kv_cache,
+ R.shape([1 + n, 32, 128]),
+ sinfo_args=(R.Tensor((1 + n, 32, 128), dtype="float32"),),
+ )
+ R.output(gv, lv)
+ return gv, lv
+
+ _check(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()