This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 31a23d8 [TVMScript] fixing block attr printing bug (#9667)
31a23d8 is described below
commit 31a23d8833f173ba761f88f7f3931f80fc646b3e
Author: Bohan Hou <[email protected]>
AuthorDate: Tue Dec 7 03:06:36 2021 -0500
[TVMScript] fixing block attr printing bug (#9667)
---
src/printer/tvmscript_printer.cc | 10 +++++++---
tests/python/unittest/test_tvmscript_roundtrip.py | 13 +++++++++++++
2 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 9fb04a2..c233f1a 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1238,10 +1238,14 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
   body << "# body" << Doc::NewLine();
   if (op->body->IsInstance<BlockRealizeNode>() &&
       op->body.as<BlockRealizeNode>()->iter_values.empty()) {
-    // Skip print root block
-    body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
     const BlockNode* block = op->body.as<BlockRealizeNode>()->block.get();
-    body << PrintBlockBody(block);
+    if (block->annotations.empty()) {
+      // Skip print root block
+      body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
+      body << PrintBlockBody(block);
+    } else {
+      body << PrintBody(op->body);
+    }
   } else {
     body << PrintBody(op->body);
   }
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 98eaa02..bf1235a 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3242,5 +3242,18 @@ def test_func_with_target_spec_by_str():
     tvm.ir.assert_structural_equal(func, rt_func, True)
 
 
+@T.prim_func
+def func_root_attr():
+    with T.block("root"):
+        T.block_attr({"a": "0"})
+        T.evaluate(0)
+
+
+def test_root_attr():
+    func = func_root_attr
+    rt_func = tvm.script.from_source(func.script(show_meta=True))
+    tvm.ir.assert_structural_equal(func, rt_func, True)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))