This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 465d691e66 [Unity] Ignore R.ExternFunc in EliminateCommonSubexpr (#15900)
465d691e66 is described below
commit 465d691e660d2a0ebe895c29f586b4db578a1189
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Oct 11 04:17:29 2023 -0500
[Unity] Ignore R.ExternFunc in EliminateCommonSubexpr (#15900)
Prior to this commit, `relax::ExternFunc` nodes would be de-duplicated
as part of the `EliminateCommonSubexpr` pass. This commit instead
ignores the `relax::ExternFunc` nodes, retaining the in-line
definitions.
---
src/relax/transform/eliminate_common_subexpr.cc | 2 +-
tests/python/relax/test_transform_cse.py | 14 ++++++++++++++
2 files changed, 15 insertions(+), 1 deletion(-)
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index fa90d41933..842470c463 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -94,7 +94,7 @@ class SubexprCounter : public ExprVisitor {
if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
- e->IsInstance<ShapeExprNode>() ||
+ e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
(e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
// also if e has an impure subexpression, we will not deduplicate it
if (!impurity_detector_.Detect(e)) {
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index cf66ae3c1c..d69ec61b5c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -276,5 +276,19 @@ def test_do_not_eliminate_shape_expr():
verify(Before, Expected)
+def test_do_not_eliminate_extern_func():
+ @I.ir_module
+ class Before:
+ @R.function(pure=False)
+ def foo(x: R.Tensor((2, 3), dtype="float32")):
+ y = R.call_packed("extern_func_name", x, sinfo_args=R.Tensor([2, 3]))
+ z = R.call_packed("extern_func_name", y, sinfo_args=R.Tensor([2, 3]))
+ return z
+
+ Expected = Before
+
+ verify(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()