This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 94fbfd254c [Unity][Fix][Pass] Fix FuseOps for lack graph edges (#14058)
94fbfd254c is described below

commit 94fbfd254cfcc3b5721acf67cff4677e16ea0774
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Feb 20 23:37:22 2023 -0500

    [Unity][Fix][Pass] Fix FuseOps for lack graph edges (#14058)
    
    This PR fixes a mistake of #14044. In #14044, in VisitLeaf of graph
    construction of FuseOps, we first check if the input node is a leaf and
    then check if it is a Tuple. This is not right: since a Tuple is not
    categorized as a leaf node, when the input node is a Tuple, the
    function will return early because the input is not a LeafNode, and the
    check for Tuple will thereby never hold.
    
    It is quite interesting that our existing unit tests failed to catch
    this mistake. I add a regression test for this case, which ensures
    that the tuple is always visited.
---
 src/relax/transform/fuse_ops.cc               |  9 +++++----
 tests/python/relax/test_transform_fuse_ops.py | 22 +++++++++++++++++++++-
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 3b78274cec..813c0c8f03 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -234,10 +234,6 @@ class GraphCreator : public ExprVisitor {
   void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* 
binding_var_node,
                  const OpPatternKind& pattern) {
     ICHECK_NOTNULL(binding_var_node);
-    if (!leaf_expr->IsInstance<LeafExprNode>()) {
-      // Skip GlobalVar, ExternFunc, OpNode.
-      return;
-    }
 
     // Recursive visit if it's Tuple
     if (const auto* tuple = leaf_expr.as<TupleNode>()) {
@@ -247,6 +243,11 @@ class GraphCreator : public ExprVisitor {
       return;
     }
 
+    if (!leaf_expr->IsInstance<LeafExprNode>()) {
+      // Skip GlobalVar, ExternFunc, OpNode.
+      return;
+    }
+
     auto it = graph_.node_map.find(leaf_expr.get());
     IndexedForwardGraph::Node* leaf_node = nullptr;
     if (it != graph_.node_map.end()) {
diff --git a/tests/python/relax/test_transform_fuse_ops.py 
b/tests/python/relax/test_transform_fuse_ops.py
index 6fad4f8165..d38e582981 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -18,7 +18,7 @@
 import tvm
 import tvm.testing
 from tvm import relax, topi
-from tvm.script import ir as I, relax as R
+from tvm.script import ir as I, relax as R, tir as T
 
 
 def _check(mod_actual, mod_expected):
@@ -834,5 +834,25 @@ def test_skip_call_dps_packed():
     _check(Module, Module)
 
 
+def test_edge_with_call_dps_packed():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")):
+            with R.dataflow():
+                a = R.call_tir(exp, (x,), out_sinfo=R.Tensor((2, 3), 
"float32"))
+                b = R.call_tir(exp, (a,), out_sinfo=R.Tensor((2, 3), 
"float32"))
+                c = R.call_tir("packed_dps", (a,), out_sinfo=R.Tensor((2, 3), 
"float32"))
+                R.output(b, c)
+            return R.tuple(b, c)
+
+        @T.prim_func
+        def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), 
"float32")):
+            T.evaluate(0)
+
+    # FuseOps should make no change to it.
+    _check(Module, Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to