This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9d27a6aeb8 Fix super() visit function in PyExprVisitor and PyExprMutator (#15189)
9d27a6aeb8 is described below

commit 9d27a6aeb8b2931c1611d4d78b25cc2a3fd9ef13
Author: Lesheng Jin <[email protected]>
AuthorDate: Fri Jun 30 09:55:03 2023 -0700

    Fix super() visit function in PyExprVisitor and PyExprMutator (#15189)
    
    As discussed in
    https://discuss.tvm.apache.org/t/recursive-visiting-in-pyexprmutator/15224,
    calling any of the super() visit functions listed below from a
    PyExprVisitor or PyExprMutator subclass runs into infinite recursion.
    
    ```python
    super().visit_binding_block_()
    super().visit_dataflow_block_()
    super().visit_var_def_()
    super().visit_dataflow_var_def_()
    ```
    
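    This PR fixes the behavior: the ExprVisitor/ExprMutator VisitBindingBlock
    and VisitVarDef FFI entry points now check the concrete node type and call
    the base-class VisitBindingBlock_ / VisitVarDef_ implementation directly.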
---
 src/relax/ir/py_expr_functor.cc         | 32 ++++++++++++--
 tests/python/relax/test_expr_functor.py | 75 +++++++++++++++++++++++++++++++--
 2 files changed, 99 insertions(+), 8 deletions(-)

diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc
index c205e2feb4..a7ac245610 100644
--- a/src/relax/ir/py_expr_functor.cc
+++ b/src/relax/ir/py_expr_functor.cc
@@ -568,12 +568,24 @@ TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding")
 
 TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock")
     .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) {
-      visitor->ExprVisitor::VisitBindingBlock(block);
+      if (const auto* ptr = block.as<DataflowBlockNode>()) {
+        visitor->ExprVisitor::VisitBindingBlock_(ptr);
+      } else if (const auto* ptr = block.as<BindingBlockNode>()) {
+        visitor->ExprVisitor::VisitBindingBlock_(ptr);
+      } else {
+        LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey();
+      }
     });
 
 TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef")
     .set_body_typed([](PyExprVisitor visitor, const Var& var) {
-      visitor->ExprVisitor::VisitVarDef(var);
+      if (const auto* node = var.as<DataflowVarNode>()) {
+        visitor->ExprVisitor::VisitVarDef_(node);
+      } else if (const auto* node = var.as<VarNode>()) {
+        visitor->ExprVisitor::VisitVarDef_(node);
+      } else {
+        LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey();
+      }
     });
 
 TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan")
@@ -621,12 +633,24 @@ TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding")
 
 TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock")
     .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) {
-      return mutator->ExprMutator::VisitBindingBlock(block);
+      if (const auto* node = block.as<DataflowBlockNode>()) {
+        return mutator->ExprMutator::VisitBindingBlock_(node);
+      } else if (const auto* node = block.as<BindingBlockNode>()) {
+        return mutator->ExprMutator::VisitBindingBlock_(node);
+      } else {
+        LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey();
+      }
     });
 
 TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef")
     .set_body_typed([](PyExprMutator mutator, const Var& var) {
-      return mutator->ExprMutator::VisitVarDef(var);
+      if (const auto* node = var.as<DataflowVarNode>()) {
+        return mutator->ExprMutator::VisitVarDef_(node);
+      } else if (const auto* node = var.as<VarNode>()) {
+        return mutator->ExprMutator::VisitVarDef_(node);
+      } else {
+        LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey();
+      }
     });
 
 TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder")
diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py
index 0c35e18f5f..c18ab3e6f6 100644
--- a/tests/python/relax/test_expr_functor.py
+++ b/tests/python/relax/test_expr_functor.py
@@ -678,7 +678,10 @@ def test_wrong_inherit():
 @R.function
 def dummy(x: R.Tensor((10, 10))):
     lv = R.add(x, R.const(1))
-    return lv
+    with R.dataflow():
+        gv = lv
+        R.output(gv)
+    return gv
 
 
 def test_call_visitor_super():
@@ -688,6 +691,14 @@ def test_call_visitor_super():
             super().__init__()
             self.log = ASTLog()
 
+        def visit_binding_block_(self, block: relax.BindingBlock) -> None:
+            self.log.add("BindingBlock")
+            super().visit_binding_block_(block)
+
+        def visit_dataflow_block_(self, block: DataflowBlock) -> None:
+            self.log.add("DataflowBlock")
+            super().visit_dataflow_block_(block)
+
         def visit_var_binding_(self, binding: relax.VarBinding) -> None:
             self.log.add("VarBinding")
             super().visit_var_binding_(binding)
@@ -696,6 +707,14 @@ def test_call_visitor_super():
             self.log.add("InternalCall")
             super().visit_call_(op)  # call PyExprVisitor.visit_call_
 
+        def visit_var_def_(self, var: Var) -> None:
+            self.log.add("VarDef")
+            super().visit_var_def_(var)
+
+        def visit_dataflow_var_def_(self, var: Var) -> None:
+            self.log.add("DataflowVarDef")
+            super().visit_dataflow_var_def_(var)
+
         def visit_var_(self, op: Var) -> None:
             self.log.add("Var")
 
@@ -719,7 +738,23 @@ def test_call_visitor_super():
 
     lv = LeafVisitor()
     lv.visit_expr(dummy)
-    assert str(lv.log) == "\n".join(["VarBinding", "LeafCall", "InternalCall", "Op", "Var", "Var"])
+    assert str(lv.log) == "\n".join(
+        [
+            "VarDef",
+            "BindingBlock",
+            "VarBinding",
+            "LeafCall",
+            "InternalCall",
+            "Op",
+            "Var",
+            "VarDef",
+            "DataflowBlock",
+            "VarBinding",
+            "Var",
+            "VarDef",
+            "Var",
+        ]
+    )
 
 
 def test_call_mutator_super():
@@ -729,14 +764,30 @@ def test_call_mutator_super():
             super().__init__()
             self.log = ASTLog()
 
+        def visit_binding_block_(self, block: relax.BindingBlock) -> None:
+            self.log.add("BindingBlock")
+            return super().visit_binding_block_(block)
+
+        def visit_dataflow_block_(self, block: DataflowBlock) -> None:
+            self.log.add("DataflowBlock")
+            return super().visit_dataflow_block_(block)
+
         def visit_var_binding_(self, binding: relax.VarBinding) -> None:
             self.log.add("VarBinding")
-            super().visit_var_binding_(binding)
+            return super().visit_var_binding_(binding)
 
         def visit_call_(self, op: Call) -> None:
             self.log.add("InternalCall")
             return super().visit_call_(op)  # call PyExprMutator.visit_call_
 
+        def visit_var_def_(self, var: Var) -> None:
+            self.log.add("VarDef")
+            return super().visit_var_def_(var)
+
+        def visit_dataflow_var_def_(self, var: Var) -> None:
+            self.log.add("DataflowVarDef")
+            return super().visit_dataflow_var_def_(var)
+
         def visit_var_(self, op: Var) -> None:
             self.log.add("Var")
             return super().visit_var_(op)  # call PyExprMutator.visit_var_
@@ -762,7 +813,23 @@ def test_call_mutator_super():
 
     lm = LeafMutator()
     lm.visit_expr(dummy)
-    assert str(lm.log) == "\n".join(["VarBinding", "LeafCall", "InternalCall", "Op", "Var", "Var"])
+    assert str(lm.log) == "\n".join(
+        [
+            "VarDef",
+            "BindingBlock",
+            "VarBinding",
+            "LeafCall",
+            "InternalCall",
+            "Op",
+            "Var",
+            "VarDef",
+            "DataflowBlock",
+            "VarBinding",
+            "Var",
+            "VarDef",
+            "Var",
+        ]
+    )
 
 
 if __name__ == "__main__":

Reply via email to