This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f56a95b87 [TVMScript] Use new variable frame in If/Then/Else (#14250)
3f56a95b87 is described below

commit 3f56a95b872ca1078b209db9d8d368c8a29c0923
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Mar 23 15:54:48 2023 -0500

    [TVMScript] Use new variable frame in If/Then/Else (#14250)
    
    Previously, while TVMScript introduces a new scope for other
    contexts (e.g. `for`, `while`, `with`, etc.), the `if` and `else`
    blocks did not introduce a new scope.  This caused spurious parse
    errors when the `if` and `else` blocks each defined a variable with
    the same name.  Wrapping each branch body (the `then` body and the
    `else` body) in its own `self.var_table.with_frame()` context
    resolves this issue.
---
 python/tvm/script/parser/tir/parser.py            |  6 ++++--
 tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 7ef2039b9d..8a067267a3 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -433,10 +433,12 @@ def visit_if(self: Parser, node: doc.If) -> None:
     with self.var_table.with_frame():
         with T.If(self.eval_expr(node.test)):
             with T.Then():
-                self.visit_body(node.body)
+                with self.var_table.with_frame():
+                    self.visit_body(node.body)
             if node.orelse:
                 with T.Else():
-                    self.visit_body(node.orelse)
+                    with self.var_table.with_frame():
+                        self.visit_body(node.orelse)
 
 
 @dispatch.register(token="tir", type_name="Assert")
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 6f07b6a75a..990fe285eb 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3623,6 +3623,19 @@ def merge_shape_var_def():
     return main
 
 
+def if_then_else_var():
+    @T.prim_func
+    def main(n: T.int32):
+        if n == 0:
+            x = 5
+            T.evaluate(x)
+        else:
+            x = 10
+            T.evaluate(x)
+
+    return main
+
+
 def tvm_shfl_builtins():
     @T.prim_func
     def func(
@@ -3740,6 +3753,7 @@ ir_generator = tvm.testing.parameter(
     let_stmt_value,
     string_stride,
     merge_shape_var_def,
+    if_then_else_var,
     tvm_shfl_builtins,
 )
 

Reply via email to