This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f56a95b87 [TVMScript] Use new variable frame in If/Then/Else (#14250)
3f56a95b87 is described below
commit 3f56a95b872ca1078b209db9d8d368c8a29c0923
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Mar 23 15:54:48 2023 -0500
[TVMScript] Use new variable frame in If/Then/Else (#14250)
Previously, while TVMScript introduces a new scope for other
contexts (e.g. `for`, `while`, `with`, etc.), the `if` and `else`
blocks did not introduce a new scope. This caused spurious parsing
errors if the `if` and `else` blocks each contained a variable with
the same name. Adding a `self.var_table.with_frame()` context resolves
this issue.
---
python/tvm/script/parser/tir/parser.py | 6 ++++--
tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++
2 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/python/tvm/script/parser/tir/parser.py
b/python/tvm/script/parser/tir/parser.py
index 7ef2039b9d..8a067267a3 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -433,10 +433,12 @@ def visit_if(self: Parser, node: doc.If) -> None:
with self.var_table.with_frame():
with T.If(self.eval_expr(node.test)):
with T.Then():
- self.visit_body(node.body)
+ with self.var_table.with_frame():
+ self.visit_body(node.body)
if node.orelse:
with T.Else():
- self.visit_body(node.orelse)
+ with self.var_table.with_frame():
+ self.visit_body(node.orelse)
@dispatch.register(token="tir", type_name="Assert")
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py
b/tests/python/unittest/test_tvmscript_roundtrip.py
index 6f07b6a75a..990fe285eb 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3623,6 +3623,19 @@ def merge_shape_var_def():
return main
+def if_then_else_var():
+ @T.prim_func
+ def main(n: T.int32):
+ if n == 0:
+ x = 5
+ T.evaluate(x)
+ else:
+ x = 10
+ T.evaluate(x)
+
+ return main
+
+
def tvm_shfl_builtins():
@T.prim_func
def func(
@@ -3740,6 +3753,7 @@ ir_generator = tvm.testing.parameter(
let_stmt_value,
string_stride,
merge_shape_var_def,
+ if_then_else_var,
tvm_shfl_builtins,
)