This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new add45b5c1c [Unity] Make FuseOps work on a call_tir which has a 
ShapeExpr arg (#14553)
add45b5c1c is described below

commit add45b5c1cc7167c8648aeed2852321d4e860a59
Author: Hongyi Jin <[email protected]>
AuthorDate: Mon Apr 10 02:34:56 2023 -0400

    [Unity] Make FuseOps work on a call_tir which has a ShapeExpr arg (#14553)
    
    The pattern of the ShapeExpr should be set to `kOpaque` in GraphCreator. 
Previously it was not handled and would run into an error.
---
 src/relax/transform/fuse_ops.cc               |  2 +-
 tests/python/relax/test_transform_fuse_ops.py | 54 +++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index b01097aa1b..a49ae86267 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -252,7 +252,7 @@ class GraphCreator : public ExprVisitor {
     IndexedForwardGraph::Node* leaf_node = nullptr;
     if (it != graph_.node_map.end()) {
       leaf_node = it->second;
-    } else if (leaf_expr->IsInstance<ConstantNode>()) {
+    } else if (leaf_expr->IsInstance<ConstantNode>() || 
leaf_expr->IsInstance<ShapeExprNode>()) {
       leaf_node = CreateNode(leaf_expr.get());
       // Since we never fuse constants, the pattern of the constant is set to 
`kOpaque`.
       SetNodePattern(leaf_node, OpPatternKind::kOpaque);
diff --git a/tests/python/relax/test_transform_fuse_ops.py 
b/tests/python/relax/test_transform_fuse_ops.py
index cf8efb0587..285a78a30e 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1333,5 +1333,59 @@ def test_symbolic_shape_aware_fuse_2():
     _check(Before, Expected)
 
 
+def test_shape_expr_arg():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(s: R.Shape(["n"]), kv_cache: R.Object):
+            n = T.int64()
+            with R.dataflow():
+                lv0 = R.emit_te(topi.full, [n, n], "float32", 0)
+                lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), 
upper=True)
+                lv2 = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
+                gv = R.call_packed(
+                    "vm.builtin.attention_kv_cache_view",
+                    kv_cache,
+                    R.shape([1 + n, 32, 128]),
+                    sinfo_args=(R.Tensor((1 + n, 32, 128), dtype="float32"),),
+                )
+                R.output(gv, lv2)
+            return gv, lv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def fused_full_trilu_broadcast_to(
+            s: R.Shape(["n"]),
+        ) -> R.Tensor([1, 1, "n", "n"], "float32"):
+            R.func_attr({"Primitive": 1})
+            n = T.int64()
+            with R.dataflow():
+                lv0 = R.emit_te(topi.full, [n, n], "float32", 0)
+                lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), 
upper=True)
+                gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(s: R.Shape(["n"]), kv_cache: R.Object):
+            cls = Expected
+            n = T.int64()
+            with R.dataflow():
+                lv: R.Tensor([1, 1, n, n], "float32") = 
cls.fused_full_trilu_broadcast_to(
+                    R.shape([n])
+                )
+                gv = R.call_packed(
+                    "vm.builtin.attention_kv_cache_view",
+                    kv_cache,
+                    R.shape([1 + n, 32, 128]),
+                    sinfo_args=(R.Tensor((1 + n, 32, 128), dtype="float32"),),
+                )
+                R.output(gv, lv)
+            return gv, lv
+
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to