This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 07478afb0c [Relax] Fix issue in fuse concat ops by pattern (#18163)
07478afb0c is described below
commit 07478afb0ce196b99b3bec775eb9e130bb911023
Author: chenxinli <[email protected]>
AuthorDate: Sat Jul 26 23:33:03 2025 +0800
[Relax] Fix issue in fuse concat ops by pattern (#18163)
* [Relax] Fix issue in fuse concat ops by pattern
* fix lint
---
src/relax/transform/fuse_ops.cc | 8 ++-
.../relax/test_transform_fuse_ops_by_pattern.py | 73 ++++++++++++++++++++++
2 files changed, 80 insertions(+), 1 deletion(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 0828e9c81c..434a7e7653 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -427,10 +427,16 @@ class FunctionCreator : public ExprMutator {
}
for (const Expr& arg : call->args) {
- CheckDefAndUpdateParam(arg);
if (GetStructInfoAs<TupleStructInfoNode>(arg) != nullptr) {
// The argument is fully referenced. Thus we remove it from the mapping.
partially_used_tuple_params_.erase(arg.get());
+ const Tuple& tup_args = Downcast<Tuple>(arg);
+ for (const Expr& tup_arg : tup_args->fields) {
+ CheckDefAndUpdateParam(tup_arg);
+        ICHECK(GetStructInfoAs<TupleStructInfoNode>(tup_arg) == nullptr);
+ }
+ } else {
+ CheckDefAndUpdateParam(arg);
}
}
}
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 999879e751..2219c01ccb 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -26,6 +26,7 @@ from tvm.relax.dpl.pattern import (
is_tuple_get_item,
make_fused_bias_activation_pattern,
wildcard,
+ is_tuple,
)
from tvm.relax.transform import PatternCheckContext
from tvm.script import ir as I
@@ -1348,5 +1349,77 @@ def test_dataflow_inside_branch():
tvm.ir.assert_structural_equal(Expected, After)
+def test_concat():
+ @R.function
+ def func(x: R.Tensor((10,), "float32"), y: R.Tensor((10,), "float32")):
+ R.func_attr({"global_symbol": "main"})
+ with R.dataflow():
+ lv = R.abs(x)
+ lv1 = R.abs(y)
+ lv2 = R.concat([lv, lv1])
+ gv = R.nn.relu(lv2)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected1:
+ @R.function(private=True)
+ def fused_relax_abs_relax_abs_relax_concat(
            x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((20,), dtype="float32"):
+ R.func_attr({"Composite": "x.concat_abs_abs", "Primitive": True})
+ with R.dataflow():
+ lv: R.Tensor((10,), dtype="float32") = R.abs(x)
+ lv1: R.Tensor((10,), dtype="float32") = R.abs(y)
+                gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
            x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((20,), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor(
+ (20,), dtype="float32"
+ ) = Expected1.fused_relax_abs_relax_abs_relax_concat(x, y)
+ gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv)
+ R.output(gv)
+ return gv
+
+ mod = tvm.IRModule({"main": func})
+    inp = is_tuple([is_op("relax.abs")(wildcard()), is_op("relax.abs")(wildcard())])
+ pat_clip = is_op("relax.concat")(inp)
+
+ check(mod, [("x.concat_abs_abs", pat_clip)], Expected1)
+
+ @I.ir_module
+ class Expected2:
+ @R.function(private=True)
+ def fused_relax_concat(
            lv: R.Tensor((10,), dtype="float32"), lv1: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((20,), dtype="float32"):
+ R.func_attr({"Composite": "x.concat", "Primitive": True})
+ with R.dataflow():
+                gv: R.Tensor((20,), dtype="float32") = R.concat((lv, lv1), axis=0)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
            x: R.Tensor((10,), dtype="float32"), y: R.Tensor((10,), dtype="float32")
+ ) -> R.Tensor((20,), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((10,), dtype="float32") = R.abs(x)
+ lv1: R.Tensor((10,), dtype="float32") = R.abs(y)
+                lv_1: R.Tensor((20,), dtype="float32") = Expected2.fused_relax_concat(lv, lv1)
+ gv: R.Tensor((20,), dtype="float32") = R.nn.relu(lv_1)
+ R.output(gv)
+ return gv
+
+ pat_clip = is_op("relax.concat")(wildcard())
+ check(mod, [("x.concat", pat_clip)], Expected2)
+
+
if __name__ == "__main__":
pytest.main([__file__])