This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 07478afb0c [Relax] Fix issue in fuse concat ops by pattern (#18163)
07478afb0c is described below

commit 07478afb0ce196b99b3bec775eb9e130bb911023
Author: chenxinli <[email protected]>
AuthorDate: Sat Jul 26 23:33:03 2025 +0800

    [Relax] Fix issue in fuse concat ops by pattern (#18163)
    
    * [Relax] Fix issue in fuse concat ops by pattern
    
    * fix lint
---
 src/relax/transform/fuse_ops.cc                    | 16 +++++
 .../relax/test_transform_fuse_ops_by_pattern.py    | 73 ++++++++++++++++++++++
 2 files changed, 89 insertions(+)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 0828e9c81c..434a7e7653 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -427,10 +427,26 @@ class FunctionCreator : public ExprMutator {
           }
 
           for (const Expr& arg : call->args) {
+            if (const auto* tuple = arg.as<TupleNode>()) {
+              // An inline tuple literal (e.g. the operand of `R.concat((lv, lv1))`) is not a
+              // variable: lifting it whole would make it a parameter even when its fields are
+              // defined inside the fused function, so check each field instead. Tuple-typed
+              // variables are not literals and take the generic path below.
+              for (const Expr& field : tuple->fields) {
+                ICHECK(field.as<TupleNode>() == nullptr)
+                    << "Nested tuple literals are not supported as call arguments";
+                CheckDefAndUpdateParam(field);
+                if (GetStructInfoAs<TupleStructInfoNode>(field) != nullptr) {
+                  // A tuple-typed field is passed as a whole, so it is fully referenced.
+                  partially_used_tuple_params_.erase(field.get());
+                }
+              }
+              continue;
+            }
             CheckDefAndUpdateParam(arg);
             if (GetStructInfoAs<TupleStructInfoNode>(arg) != nullptr) {
               // The argument is fully referenced. Thus we remove it from the mapping.
               partially_used_tuple_params_.erase(arg.get());
             }
           }
         }
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 999879e751..2219c01ccb 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -26,6 +26,7 @@ from tvm.relax.dpl.pattern import (
     is_tuple_get_item,
     make_fused_bias_activation_pattern,
     wildcard,
+    is_tuple,
 )
 from tvm.relax.transform import PatternCheckContext
 from tvm.script import ir as I
@@ -1348,5 +1349,77 @@ def test_dataflow_inside_branch():
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_concat():
+    @R.function
+    def func(x: R.Tensor((10,), "float32"), y: R.Tensor((10,), "float32")):
+        R.func_attr({"global_symbol": "main"})
+        with R.dataflow():
+            lv = R.abs(x)
+            lv1 = R.abs(y)
+            lv2 = R.concat([lv, lv1])
+            gv = R.nn.relu(lv2)
+            R.output(gv)
+        return gv
+
+    @I.ir_module
+    class Expected1:
+        @R.function(private=True)
+        def fused_relax_abs_relax_abs_relax_concat(
+            x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32")
+        ) -> R.Tensor((20,), dtype="float32"):
+            R.func_attr({"Composite": "x.concat_abs_abs", "Primitive": True})
+            with R.dataflow():
+                lv: R.Tensor((10,), dtype="float32") = R.abs(x)
+                lv1: R.Tensor((10,), dtype="float32") = R.abs(y)
+                gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32")
+        ) -> R.Tensor((20,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor(
+                    (20,), dtype="float32"
+                ) = Expected1.fused_relax_abs_relax_abs_relax_concat(x, y)
+                gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv)
+                R.output(gv)
+            return gv
+
+    mod = tvm.IRModule({"main": func})
+    inp = is_tuple([is_op("relax.abs")(wildcard()), is_op("relax.abs")(wildcard())])
+    pat_concat = is_op("relax.concat")(inp)  # tuple fields defined inside the fused function
+
+    check(mod, [("x.concat_abs_abs", pat_concat)], Expected1)
+
+    @I.ir_module
+    class Expected2:
+        @R.function(private=True)
+        def fused_relax_concat(
+            lv: R.Tensor((10,), dtype="float32"), lv1: R.Tensor((10,), dtype="float32")
+        ) -> R.Tensor((20,), dtype="float32"):
+            R.func_attr({"Composite": "x.concat", "Primitive": True})
+            with R.dataflow():
+                gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32")
+        ) -> R.Tensor((20,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10,), dtype="float32") = R.abs(x)
+                lv1: R.Tensor((10,), dtype="float32") = R.abs(y)
+                lv_1: R.Tensor((20,), dtype="float32") = Expected2.fused_relax_concat(lv, lv1)
+                gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv_1)
+                R.output(gv)
+            return gv
+
+    pat_concat = is_op("relax.concat")(wildcard())  # tuple fields become parameters
+    check(mod, [("x.concat", pat_concat)], Expected2)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to