This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 98008c2d67 [Fix][TVMScript] Fix `LetStmt` printing logic (#13900)
98008c2d67 is described below
commit 98008c2d676dd14b43ec97050f935217c014dc49
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Feb 3 10:34:36 2023 -0800
[Fix][TVMScript] Fix `LetStmt` printing logic (#13900)
This PR fixes the bug reported in #13892. Previously, we mixed the logic of
the `LetStmt` docsifying method for the cases with and without concise scoping. For example, in
```python
x = T.var("int32")
with T.let(x, 0):
```
`x` in the `LetStmt` works as a right value, while in
```python
x: T.int32 = 0
```
`x` in the `LetStmt` works as a left value.
Our old logic handled both cases the same way and therefore generated the wrong code for the first
case.
Meanwhile, during the fix, we found another bug in the concise scoping check.
For example, we have
```python
x = T.var("int32")
y = T.var("int32")
with T.let(x, y):
with T.let(y, 0):
```
here we should not output
```python
x = T.var("int32")
y = T.var("int32")
with T.let(x, y):
y: int32 = 0
```
because this would actually define a new variable `y_1: int32 = 0`, due to the variable
shadowing logic of the parser, and that new variable is different from the `y` we define and
refer to.
The concise scoping form `v: ... = ...` should be used if and only if `v` has
not been defined before.
Otherwise, we use `with T.let(v, ...):` instead.
---
src/script/printer/tir/stmt.cc | 18 ++++++++------
.../python/unittest/test_tvmscript_printer_tir.py | 1 +
tests/python/unittest/test_tvmscript_roundtrip.py | 28 ++++++++++++++++++++++
3 files changed, 40 insertions(+), 7 deletions(-)
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index 2820f9ba63..7556f820df 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -57,13 +57,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p,
IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
- ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
- With<TIRFrame> f(d, stmt);
- ExprDoc lhs = d->IsVarDefined(stmt->var) ?
d->GetVarDoc(stmt->var).value()
- : DefineVar(stmt->var, *f, d);
- AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
- Array<StmtDoc>* stmts = &(*f)->stmts;
- if (concise) {
+ if (concise && !d->IsVarDefined(stmt->var)) {
+ ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
+ With<TIRFrame> f(d, stmt);
+ ExprDoc lhs = DefineVar(stmt->var, *f, d);
+ AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
+ Array<StmtDoc>* stmts = &(*f)->stmts;
Type type = stmt->var->type_annotation;
Optional<ExprDoc> type_doc =
d->AsDoc<ExprDoc>(type, p->Attr("var")->Attr("type_annotation"));
@@ -75,6 +74,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
return StmtBlockDoc(*stmts);
} else {
+ ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
+ ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
+ With<TIRFrame> f(d, stmt);
+ AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
+ Array<StmtDoc>* stmts = &(*f)->stmts;
rhs = TIR(d, "let")->Call({lhs, rhs});
return ScopeDoc(NullOpt, rhs, *stmts);
}
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py
b/tests/python/unittest/test_tvmscript_printer_tir.py
index 49a33cd0f0..6f96b3a3dd 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -254,6 +254,7 @@ def test_let_stmt():
_assert_print(
obj,
"""
+v = T.var("float32")
with T.let(v, T.float32(10)):
T.evaluate(0)
""",
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py
b/tests/python/unittest/test_tvmscript_roundtrip.py
index 4300c4bbad..f52b488fef 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3543,6 +3543,32 @@ def intrinsic_pow():
return func
+def let_stmt_var():
+ @T.prim_func
+ def func():
+ x = T.var("int32")
+ y = T.var("int32")
+ with T.let(x, 0):
+ with T.let(y, 0):
+ T.evaluate(0)
+ T.evaluate(0)
+
+ return func
+
+
+def let_stmt_value():
+ @T.prim_func
+ def func():
+ x = T.var("int32")
+ y = T.var("int32")
+ with T.let(x, y):
+ with T.let(y, 0):
+ T.evaluate(0)
+ T.evaluate(0)
+
+ return func
+
+
ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
@@ -3601,6 +3627,8 @@ ir_generator = tvm.testing.parameter(
*nested_boolean_expressions(),
multi_env_threads,
intrinsic_pow,
+ let_stmt_var,
+ let_stmt_value,
)