This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 31a23d8  [TVMScript] fixing block attr printing bug (#9667)
31a23d8 is described below

commit 31a23d8833f173ba761f88f7f3931f80fc646b3e
Author: Bohan Hou <[email protected]>
AuthorDate: Tue Dec 7 03:06:36 2021 -0500

    [TVMScript] fixing block attr printing bug (#9667)
---
 src/printer/tvmscript_printer.cc                  | 10 +++++++---
 tests/python/unittest/test_tvmscript_roundtrip.py | 13 +++++++++++++
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 9fb04a2..c233f1a 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1238,10 +1238,14 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
   body << "# body" << Doc::NewLine();
   if (op->body->IsInstance<BlockRealizeNode>() &&
       op->body.as<BlockRealizeNode>()->iter_values.empty()) {
-    // Skip print root block
-    body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
     const BlockNode* block = op->body.as<BlockRealizeNode>()->block.get();
-    body << PrintBlockBody(block);
+    if (block->annotations.empty()) {
+      // Skip print root block
+      body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
+      body << PrintBlockBody(block);
+    } else {
+      body << PrintBody(op->body);
+    }
   } else {
     body << PrintBody(op->body);
   }
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 98eaa02..bf1235a 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3242,5 +3242,18 @@ def test_func_with_target_spec_by_str():
     tvm.ir.assert_structural_equal(func, rt_func, True)
 
 
+@T.prim_func
+def func_root_attr():
+    with T.block("root"):
+        T.block_attr({"a": "0"})
+        T.evaluate(0)
+
+
+def test_root_attr():
+    func = func_root_attr
+    rt_func = tvm.script.from_source(func.script(show_meta=True))
+    tvm.ir.assert_structural_equal(func, rt_func, True)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to