This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ac9fb98857 [TVMScript] Implicit root block syntax sugar for TVMScript
printer (#13819)
ac9fb98857 is described below
commit ac9fb98857f68d5406c902db209548a1a6a1e9c1
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Jan 20 21:25:55 2023 -0800
[TVMScript] Implicit root block syntax sugar for TVMScript printer (#13819)
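This PR implements the implicit root block syntax sugar for the new
TVMScript printer. When the root block realize is simple (no iter values, a
predicate that is always true, and a root block with no annotations, match
buffers, read/write regions, or init), the printer skips `T.block("root")` and
prints the root block's buffer allocations and body directly in the function.
The root block is reconstructed by `tvm::tir::ScriptComplete` when the printed
script is parsed back (roundtripping). For example, it will change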
```python
@T.prim_func
def root_block_explicitly():
    with T.block("root"):
        a = T.alloc_buffer([128, 128])
        for i, j in T.grid(128, 128):
            with T.block():
                T.evaluate(0)
```
into
```python
@T.prim_func
def main():
    a = T.alloc_buffer((128, 128))
    for i, j in T.grid(128, 128):
        with T.block(""):
            T.evaluate(0)
```
---
src/script/printer/tir/function.cc | 35 ++++++++++++++-
.../python/unittest/test_tvmscript_printer_tir.py | 52 +++++++++++++++++-----
2 files changed, 74 insertions(+), 13 deletions(-)
diff --git a/src/script/printer/tir/function.cc
b/src/script/printer/tir/function.cc
index 40957fcffa..ea7d56e165 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -131,7 +131,40 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
}
// Step 4. Handle `func->body`
- AsDocBody(func->body, p->Attr("body"), frame->get(), d);
+ Optional<tir::Block> implicit_root_block = [&]() -> Optional<tir::Block>
{
+ const tir::BlockRealizeNode* root_block_realize =
func->body.as<tir::BlockRealizeNode>();
+ if (root_block_realize && !root_block_realize->iter_values.size() &&
+ tir::is_one(root_block_realize->predicate)) {
+ tir::Block root_block = root_block_realize->block;
+ if (!root_block->annotations.size() &&
!root_block->match_buffers.size() &&
+ !root_block->reads.size() && !root_block->writes.size() &&
+ !root_block->init.defined()) {
+ const tir::BlockRealizeNode* block_realize =
+ root_block->body.as<tir::BlockRealizeNode>();
+ if (root_block->alloc_buffers.size() ||
+ (block_realize && block_realize->block->iter_vars.size()) ||
+ (!block_realize &&
tir::ContainsNode<tir::BlockRealizeNode>(root_block->body))) {
+ return root_block;
+ }
+ }
+ }
+ return NullOpt;
+ }();
+ if (implicit_root_block) {
+ tir::Block root_block = implicit_root_block.value();
+ ObjectPath root_block_p = p->Attr("body")->Attr("body");
+ // Handle root block `alloc_buffer`
+ for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) {
+ tir::Buffer buffer = root_block->alloc_buffers[i];
+ ObjectPath buffer_p =
root_block_p->Attr("alloc_buffers")->ArrayIndex(i);
+ IdDoc lhs = DefineBuffer(buffer, *frame, d);
+ ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p,
*frame, d);
+ (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+ }
+ AsDocBody(root_block->body, root_block_p->Attr("body"), frame->get(),
d);
+ } else {
+ AsDocBody(func->body, p->Attr("body"), frame->get(), d);
+ }
Optional<ExprDoc> ret_type = NullOpt;
if (func->ret_type.defined()) {
const auto* as_tuple = func->ret_type.as<TupleTypeNode>();
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py
b/tests/python/unittest/test_tvmscript_printer_tir.py
index 5d86a88608..d57d104670 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -717,21 +717,49 @@ def test_remap():
expected_output = """@T.prim_func
def main():
- with T.block("root"):
- T.reads()
- T.writes()
- for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
- with T.block("update"):
- v0 = T.axis.spatial(128, i0 + 1)
- v1, v2 = T.axis.remap("SR", [i1, i2])
- v3 = T.axis.spatial(128, i3 - 1)
- v4, v5 = T.axis.remap("RS", [i4, i5])
- T.reads()
- T.writes()
- T.evaluate(0)"""
+ for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+ with T.block("update"):
+ v0 = T.axis.spatial(128, i0 + 1)
+ v1, v2 = T.axis.remap("SR", [i1, i2])
+ v3 = T.axis.spatial(128, i3 - 1)
+ v4, v5 = T.axis.remap("RS", [i4, i5])
+ T.reads()
+ T.writes()
+ T.evaluate(0)"""
_assert_print(block_with_remap_explicitly, expected_output)
_assert_print(block_with_remap_implicitly, expected_output)
+def test_root_block():
+ from tvm.script import tir as T
+
+ @T.prim_func
+ def root_block_implicitly():
+ a = T.alloc_buffer([128, 128])
+ for i, j in T.grid(128, 128):
+ with T.block():
+ T.evaluate(0)
+
+ @T.prim_func
+ def root_block_explicitly():
+ with T.block("root"):
+ a = T.alloc_buffer([128, 128])
+ for i, j in T.grid(128, 128):
+ with T.block():
+ T.evaluate(0)
+
+ expected_output = """@T.prim_func
+def main():
+ a = T.alloc_buffer((128, 128))
+ for i, j in T.grid(128, 128):
+ with T.block(""):
+ T.reads()
+ T.writes()
+ T.evaluate(0)
+ """
+ _assert_print(root_block_implicitly, expected_output)
+ _assert_print(root_block_explicitly, expected_output)
+
+
if __name__ == "__main__":
tvm.testing.main()