This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1d89071863 [TVMScript] More concise `T.allocate` syntax printing
(#13830)
1d89071863 is described below
commit 1d89071863abc6af615a288aac61a919de02f1e6
Author: Yaxing Cai <[email protected]>
AuthorDate: Mon Jan 23 23:13:35 2023 -0800
[TVMScript] More concise `T.allocate` syntax printing (#13830)
This PR is a follow-up to #13813. It simplifies the printed output of
`T.allocate` when it is paired with `T.decl_buffer`. For example, consider the following code snippet:
```python
buffer_data = T.allocate(...)
buffer = T.decl_buffer(..., data=buffer_data)
T.evaluate(buffer_data)
```
Originally, the `T.allocate` was omitted only if the var `buffer_data` defined
by `T.allocate` was used exactly once, namely by the following `T.decl_buffer`.
This restriction came from a limitation of the old printer design.
The new printer, however, can automatically replace `buffer_data` with
`buffer.data` when the definition of `buffer_data` is omitted, because it links
all usages of `buffer_data` together. The new output therefore looks like:
```python
buffer = T.decl_buffer(...)
T.evaluate(buffer.data)
```
---
src/script/printer/tir/stmt.cc | 4 +---
tests/python/unittest/test_tvmscript_printer_tir.py | 7 +++----
2 files changed, 4 insertions(+), 7 deletions(-)
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index 7c8d44c10e..acdfd7da47 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -175,9 +175,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Allocate>( //
"", [](tir::Allocate stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
- OccurrenceCounter counter(stmt->buffer_var.get());
- counter(stmt->body);
- if (counter.count == 1 && IsAllocateDeclBufferPattern(stmt.get())) {
+ if (IsAllocateDeclBufferPattern(stmt.get())) {
return d->AsDoc(stmt->body, stmt_p->Attr("body"));
}
Array<ExprDoc> args;
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py
b/tests/python/unittest/test_tvmscript_printer_tir.py
index d57d104670..c73ae29193 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -343,7 +343,7 @@ with T.decl_buffer((128, 128)) as buffer:
)
-def test_allocate_with_decl_buffer_no_sugar_multi_usage():
+def test_allocate_with_decl_buffer_sugar_multi_usage():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([128, 128], "float32", data=buffer_data) as
buffer:
@@ -352,9 +352,8 @@ def test_allocate_with_decl_buffer_no_sugar_multi_usage():
_assert_print(
obj,
"""
-with T.allocate([128, 128], "float32", "global") as v:
- buffer = T.decl_buffer((128, 128), data=v)
- T.evaluate(v)
+with T.decl_buffer((128, 128)) as buffer:
+ T.evaluate(buffer.data)
""",
)