This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ac9fb98857 [TVMScript] Implicit root block syntax sugar for TVMScript printer (#13819)
ac9fb98857 is described below

commit ac9fb98857f68d5406c902db209548a1a6a1e9c1
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Jan 20 21:25:55 2023 -0800

    [TVMScript] Implicit root block syntax sugar for TVMScript printer (#13819)
    
    This PR implements the implicit root block syntax sugar for the new
    TVMScript printer. When the root block realize is simple (for example, it
    has no iter values, no annotations, no match buffers, no reads or writes,
    and no init statement), the printer omits `T.block("root")`. The root
    block is then reconstructed by `tvm::tir::ScriptComplete` when the printed
    script is parsed back (round-tripped). For example, it changes
    ```python
    @T.prim_func
    def root_block_explicitly():
      with T.block("root"):
        a = T.alloc_buffer([128, 128])
        for i, j in T.grid(128, 128):
          with T.block():
            T.evaluate(0)
    ```
    into
    ```python
    @T.prim_func
    def main():
      a = T.alloc_buffer((128, 128))
      for i, j in T.grid(128, 128):
        with T.block(""):
          T.evaluate(0)
    ```
---
 src/script/printer/tir/function.cc                 | 35 ++++++++++++++-
 .../python/unittest/test_tvmscript_printer_tir.py  | 52 +++++++++++++++++-----
 2 files changed, 74 insertions(+), 13 deletions(-)

diff --git a/src/script/printer/tir/function.cc 
b/src/script/printer/tir/function.cc
index 40957fcffa..ea7d56e165 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -131,7 +131,40 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
         }
       }
       // Step 4. Handle `func->body`
-      AsDocBody(func->body, p->Attr("body"), frame->get(), d);
+      Optional<tir::Block> implicit_root_block = [&]() -> Optional<tir::Block> 
{
+        const tir::BlockRealizeNode* root_block_realize = 
func->body.as<tir::BlockRealizeNode>();
+        if (root_block_realize && !root_block_realize->iter_values.size() &&
+            tir::is_one(root_block_realize->predicate)) {
+          tir::Block root_block = root_block_realize->block;
+          if (!root_block->annotations.size() && 
!root_block->match_buffers.size() &&
+              !root_block->reads.size() && !root_block->writes.size() &&
+              !root_block->init.defined()) {
+            const tir::BlockRealizeNode* block_realize =
+                root_block->body.as<tir::BlockRealizeNode>();
+            if (root_block->alloc_buffers.size() ||
+                (block_realize && block_realize->block->iter_vars.size()) ||
+                (!block_realize && 
tir::ContainsNode<tir::BlockRealizeNode>(root_block->body))) {
+              return root_block;
+            }
+          }
+        }
+        return NullOpt;
+      }();
+      if (implicit_root_block) {
+        tir::Block root_block = implicit_root_block.value();
+        ObjectPath root_block_p = p->Attr("body")->Attr("body");
+        // Handle root block `alloc_buffer`
+        for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) {
+          tir::Buffer buffer = root_block->alloc_buffers[i];
+          ObjectPath buffer_p = 
root_block_p->Attr("alloc_buffers")->ArrayIndex(i);
+          IdDoc lhs = DefineBuffer(buffer, *frame, d);
+          ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, 
*frame, d);
+          (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+        }
+        AsDocBody(root_block->body, root_block_p->Attr("body"), frame->get(), 
d);
+      } else {
+        AsDocBody(func->body, p->Attr("body"), frame->get(), d);
+      }
       Optional<ExprDoc> ret_type = NullOpt;
       if (func->ret_type.defined()) {
         const auto* as_tuple = func->ret_type.as<TupleTypeNode>();
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py 
b/tests/python/unittest/test_tvmscript_printer_tir.py
index 5d86a88608..d57d104670 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -717,21 +717,49 @@ def test_remap():
 
     expected_output = """@T.prim_func
 def main():
-    with T.block("root"):
-        T.reads()
-        T.writes()
-        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
-            with T.block("update"):
-                v0 = T.axis.spatial(128, i0 + 1)
-                v1, v2 = T.axis.remap("SR", [i1, i2])
-                v3 = T.axis.spatial(128, i3 - 1)
-                v4, v5 = T.axis.remap("RS", [i4, i5])
-                T.reads()
-                T.writes()
-                T.evaluate(0)"""
+    for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+        with T.block("update"):
+            v0 = T.axis.spatial(128, i0 + 1)
+            v1, v2 = T.axis.remap("SR", [i1, i2])
+            v3 = T.axis.spatial(128, i3 - 1)
+            v4, v5 = T.axis.remap("RS", [i4, i5])
+            T.reads()
+            T.writes()
+            T.evaluate(0)"""
     _assert_print(block_with_remap_explicitly, expected_output)
     _assert_print(block_with_remap_implicitly, expected_output)
 
 
+def test_root_block():
+    from tvm.script import tir as T
+
+    @T.prim_func
+    def root_block_implicitly():
+        a = T.alloc_buffer([128, 128])
+        for i, j in T.grid(128, 128):
+            with T.block():
+                T.evaluate(0)
+
+    @T.prim_func
+    def root_block_explicitly():
+        with T.block("root"):
+            a = T.alloc_buffer([128, 128])
+            for i, j in T.grid(128, 128):
+                with T.block():
+                    T.evaluate(0)
+
+    expected_output = """@T.prim_func
+def main():
+    a = T.alloc_buffer((128, 128))
+    for i, j in T.grid(128, 128):
+        with T.block(""):
+            T.reads()
+            T.writes()
+            T.evaluate(0)
+    """
+    _assert_print(root_block_implicitly, expected_output)
+    _assert_print(root_block_explicitly, expected_output)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to