This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 465d691e66 [Unity] Ignore R.ExternFunc in EliminateCommonSubexpr (#15900)
465d691e66 is described below

commit 465d691e660d2a0ebe895c29f586b4db578a1189
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Oct 11 04:17:29 2023 -0500

    [Unity] Ignore R.ExternFunc in EliminateCommonSubexpr (#15900)
    
    Prior to this commit, `relax::ExternFunc` nodes would be de-duplicated
    as part of the `EliminateCommonSubexpr` pass.  This commit instead
    ignores the `relax::ExternFunc` nodes, retaining the in-line
    definitions.
---
 src/relax/transform/eliminate_common_subexpr.cc |  2 +-
 tests/python/relax/test_transform_cse.py        | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index fa90d41933..842470c463 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -94,7 +94,7 @@ class SubexprCounter : public ExprVisitor {
     if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
           e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
           e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
-          e->IsInstance<ShapeExprNode>() ||
+          e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
           (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
       // also if e has an impure subexpression, we will not deduplicate it
       if (!impurity_detector_.Detect(e)) {
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index cf66ae3c1c..d69ec61b5c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -276,5 +276,19 @@ def test_do_not_eliminate_shape_expr():
     verify(Before, Expected)
 
 
+def test_do_not_eliminate_extern_func():
+    @I.ir_module
+    class Before:
+        @R.function(pure=False)
+        def foo(x: R.Tensor((2, 3), dtype="float32")):
+            y = R.call_packed("extern_func_name", x, sinfo_args=R.Tensor([2, 
3]))
+            z = R.call_packed("extern_func_name", y, sinfo_args=R.Tensor([2, 
3]))
+            return z
+
+    Expected = Before
+
+    verify(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to