This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 94fbfd254c [Unity][Fix][Pass] Fix FuseOps for lack graph edges (#14058)
94fbfd254c is described below
commit 94fbfd254cfcc3b5721acf67cff4677e16ea0774
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Feb 20 23:37:22 2023 -0500
[Unity][Fix][Pass] Fix FuseOps for lack graph edges (#14058)
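This PR fixes a mistake introduced in #14044. In #14044, `VisitLeaf` in the
graph construction of FuseOps first checks whether the input node is a leaf
and only then checks whether it is a Tuple. This ordering is wrong: Tuple is
not categorized as a leaf node, so when the input node is a Tuple the
function returns early because the input is not a `LeafExprNode`. As a
result, the Tuple check can never be reached: the Tuple's fields are never
recursively visited, and the corresponding graph edges are missing. This PR
moves the leaf check after the Tuple handling.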
Interestingly, our existing unit tests did not catch this mistake. This PR
adds a regression test for this case, which ensures that the tuple is always
visited.
---
src/relax/transform/fuse_ops.cc | 9 +++++----
tests/python/relax/test_transform_fuse_ops.py | 22 +++++++++++++++++++++-
2 files changed, 26 insertions(+), 5 deletions(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 3b78274cec..813c0c8f03 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -234,10 +234,6 @@ class GraphCreator : public ExprVisitor {
void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node*
binding_var_node,
const OpPatternKind& pattern) {
ICHECK_NOTNULL(binding_var_node);
- if (!leaf_expr->IsInstance<LeafExprNode>()) {
- // Skip GlobalVar, ExternFunc, OpNode.
- return;
- }
// Recursive visit if it's Tuple
if (const auto* tuple = leaf_expr.as<TupleNode>()) {
@@ -247,6 +243,11 @@ class GraphCreator : public ExprVisitor {
return;
}
+ if (!leaf_expr->IsInstance<LeafExprNode>()) {
+ // Skip GlobalVar, ExternFunc, OpNode.
+ return;
+ }
+
auto it = graph_.node_map.find(leaf_expr.get());
IndexedForwardGraph::Node* leaf_node = nullptr;
if (it != graph_.node_map.end()) {
diff --git a/tests/python/relax/test_transform_fuse_ops.py
b/tests/python/relax/test_transform_fuse_ops.py
index 6fad4f8165..d38e582981 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -18,7 +18,7 @@
import tvm
import tvm.testing
from tvm import relax, topi
-from tvm.script import ir as I, relax as R
+from tvm.script import ir as I, relax as R, tir as T
def _check(mod_actual, mod_expected):
@@ -834,5 +834,25 @@ def test_skip_call_dps_packed():
_check(Module, Module)
+def test_edge_with_call_dps_packed():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32")):
+ with R.dataflow():
+ a = R.call_tir(exp, (x,), out_sinfo=R.Tensor((2, 3),
"float32"))
+ b = R.call_tir(exp, (a,), out_sinfo=R.Tensor((2, 3),
"float32"))
+ c = R.call_tir("packed_dps", (a,), out_sinfo=R.Tensor((2, 3),
"float32"))
+ R.output(b, c)
+ return R.tuple(b, c)
+
+ @T.prim_func
+ def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3),
"float32")):
+ T.evaluate(0)
+
+ # FuseOps should does no change to it.
+ _check(Module, Module)
+
+
if __name__ == "__main__":
tvm.testing.main()