This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 697fdb2cb7 [TVMScript] Comments and docstrings printing (#13839)
697fdb2cb7 is described below
commit 697fdb2cb7dc7ad07ed826f908390b88106cc98f
Author: Yaxing Cai <[email protected]>
AuthorDate: Wed Jan 25 21:42:31 2023 -0800
[TVMScript] Comments and docstrings printing (#13839)
This PR introduces `CommentDoc` for printing comments and
`DocStringDoc` for printing docstrings. They allow free-standing comments and
docstrings to be emitted as statements when printing, e.g.
```python
# comment 1
# comment 2
"""
docstring 1
docstring 2
"""
```
"Free-standing" here means the comment is not bound to any other `stmt` but
acts as a `stmt` on its own, similar to how `ExprStmtDoc` wraps an `ExprDoc`.
This PR also adds an example use of `CommentDoc`, as a follow-up to
#13819.
The old printer always printed a `# with T.block("root"):` line when an
implicit root block was skipped during printing. For example,
```
@T.prim_func
def main():
# with T.block("root"):
a = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block(""):
...
```
We bring this syntax reminder back in this PR.
In addition, we add an `ir_usage` field to `IRDocsifierNode`, which records
the IR dialects used while printing, to support printing import headers for
`IRModule` and `PrimFunc`. For example,
```python
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def func():
...
```
---
include/tvm/script/printer/doc.h | 44 +++++++++++++++++
include/tvm/script/printer/ir_docsifier.h | 4 ++
python/tvm/script/printer/doc.py | 20 ++++++++
src/script/printer/doc.cc | 22 +++++++++
src/script/printer/doc_printer/base_doc_printer.cc | 4 ++
src/script/printer/doc_printer/base_doc_printer.h | 10 ++++
.../printer/doc_printer/python_doc_printer.cc | 34 ++++++++++---
src/script/printer/ir/ir.cc | 3 +-
src/script/printer/ir/utils.h | 1 +
src/script/printer/tir/function.cc | 4 +-
src/script/printer/tir/utils.h | 1 +
src/script/printer/utils.h | 20 ++++++++
.../python/unittest/test_tvmscript_printer_doc.py | 28 +++++++++++
tests/python/unittest/test_tvmscript_printer_ir.py | 3 ++
.../test_tvmscript_printer_python_doc_printer.py | 56 +++++++++++++++++++++-
.../python/unittest/test_tvmscript_printer_tir.py | 18 ++++++-
16 files changed, 261 insertions(+), 11 deletions(-)
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 6504e2c284..6321caa4e0 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -1194,6 +1194,50 @@ class ClassDoc : public StmtDoc {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode);
};
+/*!
+ * \brief Doc that represents a free-standing comment, stored in the inherited `comment` field.
+ *
+ * \sa CommentDoc
+ */
+class CommentDocNode : public StmtDocNode {
+ public:
+ static constexpr const char* _type_key = "script.printer.CommentDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(CommentDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of CommentDocNode.
+ *
+ * \sa CommentDocNode
+ */
+class CommentDoc : public StmtDoc {
+ public:
+ explicit CommentDoc(String comment);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CommentDoc, StmtDoc, CommentDocNode);
+};
+
+/*!
+ * \brief Doc that represents a docstring, stored in the inherited `comment` field.
+ *
+ * \sa DocStringDoc
+ */
+class DocStringDocNode : public StmtDocNode {
+ public:
+ static constexpr const char* _type_key = "script.printer.DocStringDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(DocStringDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of DocStringDocNode.
+ *
+ * \sa DocStringDocNode
+ */
+class DocStringDoc : public StmtDoc {
+ public:
+ explicit DocStringDoc(String docs);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DocStringDoc, StmtDoc, DocStringDocNode);
+};
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h
index 67fa96ef80..c41827fe95 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -24,6 +24,7 @@
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier_functor.h>
+#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
@@ -148,6 +149,8 @@ class IRDocsifierNode : public Object {
std::unordered_set<String> defined_names;
/*! \brief Common prefixes of variable usages */
std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
+ /*! \brief Names of the IR dialects ("ir", "tir", ...) used while printing; selects import headers */
+ std::unordered_set<std::string> ir_usage;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("frames", &frames);
@@ -156,6 +159,7 @@ class IRDocsifierNode : public Object {
// `obj2info` is not visited
// `defined_names` is not visited
// `common_prefix` is not visited
+ // `ir_usage` is not visited
}
static constexpr const char* _type_key = "script.printer.IRDocsifier";
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index 5a4a4cd67a..9a6e7f1b8c 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -521,3 +521,23 @@ class ClassDoc(StmtDoc):
decorators,
body,
)
+
+
+@register_object("script.printer.CommentDoc")
+class CommentDoc(StmtDoc):
+ """Doc that represents a free-standing comment statement."""
+
+ def __init__(self, comment: str):
+ self.__init_handle_by_constructor__(
+ _ffi_api.CommentDoc, comment # type: ignore # pylint: disable=no-member
+ )
+
+
+@register_object("script.printer.DocStringDoc")
+class DocStringDoc(StmtDoc):
+ """Doc that represents a docstring statement."""
+
+ def __init__(self, docs: str):
+ self.__init_handle_by_constructor__(
+ _ffi_api.DocStringDoc, docs # type: ignore # pylint: disable=no-member
+ )
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index 89f6b7c8b1..1db4e090dc 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -221,6 +221,18 @@ ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) {
this->data_ = std::move(n);
}
+CommentDoc::CommentDoc(String comment) {
+ ObjectPtr<CommentDocNode> n = make_object<CommentDocNode>();
+ n->comment = std::move(comment);
+ this->data_ = std::move(n);
+}
+
+DocStringDoc::DocStringDoc(String docs) {
+ ObjectPtr<DocStringDocNode> n = make_object<DocStringDocNode>();
+ n->comment = std::move(docs);
+ this->data_ = std::move(n);
+}
+
TVM_REGISTER_NODE_TYPE(DocNode);
TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
.set_body_typed([](Doc doc, Array<ObjectPath> source_paths) {
@@ -365,6 +377,16 @@ TVM_REGISTER_GLOBAL("script.printer.ClassDoc")
return ClassDoc(name, decorators, body);
});
+TVM_REGISTER_NODE_TYPE(CommentDocNode);
+TVM_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) {
+ return CommentDoc(comment);
+});
+
+TVM_REGISTER_NODE_TYPE(DocStringDocNode);
+TVM_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) {
+ return DocStringDoc(docs);
+});
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc
index a3a5c06ede..8df599347f 100644
--- a/src/script/printer/doc_printer/base_doc_printer.cc
+++ b/src/script/printer/doc_printer/base_doc_printer.cc
@@ -316,6 +316,10 @@ void DocPrinter::PrintDoc(const Doc& doc) {
PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ClassDocNode>()) {
PrintTypedDoc(GetRef<ClassDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<CommentDocNode>()) {
+ PrintTypedDoc(GetRef<CommentDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<DocStringDocNode>()) {
+ PrintTypedDoc(GetRef<DocStringDoc>(doc_node));
} else {
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
throw;
diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h
index 7851ce061b..f5cf40a233 100644
--- a/src/script/printer/doc_printer/base_doc_printer.h
+++ b/src/script/printer/doc_printer/base_doc_printer.h
@@ -204,6 +204,16 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const ClassDoc& doc) = 0;
+ /*!
+ * \brief Virtual method to print a CommentDoc (a free-standing comment statement)
+ */
+ virtual void PrintTypedDoc(const CommentDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a DocStringDoc (a docstring statement)
+ */
+ virtual void PrintTypedDoc(const DocStringDoc& doc) = 0;
+
/*!
* \brief Increase the indent level of any content to be
* printed after this call
diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc
index ce6b8e7f42..334f76f722 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -169,6 +169,8 @@ class PythonDocPrinter : public DocPrinter {
void PrintTypedDoc(const ScopeDoc& doc) final;
void PrintTypedDoc(const FunctionDoc& doc) final;
void PrintTypedDoc(const ClassDoc& doc) final;
+ void PrintTypedDoc(const CommentDoc& doc) final;
+ void PrintTypedDoc(const DocStringDoc& doc) final;
private:
void NewLineWithoutIndent() { output_ << "\n"; }
@@ -253,11 +255,19 @@ class PythonDocPrinter : public DocPrinter {
}
}
- void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
+ void MaybePrintCommentMultiLines(const StmtDoc& stmt, bool new_line = false) {
if (stmt->comment.defined()) {
std::vector<std::string> comment_lines =
support::Split(stmt->comment.value(), '\n');
+ bool first_line = true;
for (const std::string& line : comment_lines) {
- output_ << "# " << line;
+ if (first_line) {
+ output_ << "# " << line;
+ first_line = false;
+ } else {
+ NewLine() << "# " << line;
+ }
+ }
+ if (new_line) {
NewLine();
}
}
@@ -523,7 +533,7 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
}
void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
- MaybePrintCommentWithNewLine(doc);
+ MaybePrintCommentMultiLines(doc, true);
output_ << "if ";
PrintDoc(doc->predicate);
output_ << ":";
@@ -538,7 +548,7 @@ void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
}
void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
- MaybePrintCommentWithNewLine(doc);
+ MaybePrintCommentMultiLines(doc, true);
output_ << "while ";
PrintDoc(doc->predicate);
output_ << ":";
@@ -547,7 +557,7 @@ void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
}
void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
- MaybePrintCommentWithNewLine(doc);
+ MaybePrintCommentMultiLines(doc, true);
output_ << "for ";
if (const auto* tuple = doc->lhs.as<TupleDocNode>()) {
if (tuple->elements.size() == 1) {
@@ -567,7 +577,7 @@ void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
}
void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
- MaybePrintCommentWithNewLine(doc);
+ MaybePrintCommentMultiLines(doc, true);
output_ << "with ";
PrintDoc(doc->rhs);
if (doc->lhs != nullptr) {
@@ -642,6 +652,28 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
NewLineWithoutIndent();
}
+void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) {
+ if (doc->comment.defined()) {
+ MaybePrintCommentMultiLines(doc, false);
+ }
+}
+
+void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) {
+ if (doc->comment.defined() && !doc->comment.value().empty()) {
+ // Break lines through NewLine() rather than a raw '\n', matching
+ // MaybePrintCommentMultiLines, so continuation lines get the printer's indentation.
+ const std::string docs = doc->comment.value();
+ output_ << "\"\"\"";
+ size_t begin = 0;
+ for (size_t end = docs.find('\n'); end != std::string::npos; end = docs.find('\n', begin)) {
+ output_ << docs.substr(begin, end - begin);
+ NewLine();
+ begin = end + 1;
+ }
+ output_ << docs.substr(begin) << "\"\"\"";
+ }
+}
+
String DocToPythonScript(Doc doc, const PrinterConfig& cfg) {
if (cfg->num_context_lines < 0) {
cfg->num_context_lines = std::numeric_limits<int32_t>::max();
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index 4a246e1692..7f7857dba6 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -119,7 +119,9 @@ std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) {
return s.value();
}
}
- Doc doc = IRDocsifier(cfg)->AsDoc(mod, ObjectPath::Root());
+ IRDocsifier d(cfg);
+ // AsDoc runs first and fills d->ir_usage, which HeaderWrapper reads to pick the headers.
+ Doc doc = HeaderWrapper(d, d->AsDoc(mod, ObjectPath::Root()));
return DocToPythonScript(doc, cfg);
}
diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h
index d20756e608..a05030516f 100644
--- a/src/script/printer/ir/utils.h
+++ b/src/script/printer/ir/utils.h
@@ -36,6 +36,7 @@ namespace printer {
/*! \brief Creates the IR common prefix, which is by default `I` */
inline ExprDoc IR(const IRDocsifier& d, const String& attr) {
+ d->ir_usage.insert("ir"); // record usage so HeaderWrapper prints the `ir` import header
return IdDoc(d->cfg->ir_prefix)->Attr(attr);
}
diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc
index fbcc2fca3b..65f3db5b4f 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -153,6 +153,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (implicit_root_block) {
tir::Block root_block = implicit_root_block.value();
ObjectPath root_block_p = p->Attr("body")->Attr("body");
+ // Remind readers of the elided root block, spelled with the configured TIR prefix.
+ (*frame)->stmts.push_back(CommentDoc("with " + d->cfg->tir_prefix + ".block(\"root\"):"));
// Handle root block `alloc_buffer`
for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) {
tir::Buffer buffer = root_block->alloc_buffers[i];
@@ -181,7 +183,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});
std::string ReprPrintPrimFunc(const ObjectRef& obj, const PrinterConfig& cfg) {
- Doc doc = IRDocsifier(cfg)->AsDoc(obj, ObjectPath::Root());
+ IRDocsifier d(cfg);
+ // AsDoc runs first and fills d->ir_usage, which HeaderWrapper reads to pick the headers.
+ Doc doc = HeaderWrapper(d, d->AsDoc(obj, ObjectPath::Root()));
return DocToPythonScript(doc, cfg);
}
diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h
index 88094ee816..0eead9a577 100644
--- a/src/script/printer/tir/utils.h
+++ b/src/script/printer/tir/utils.h
@@ -74,6 +74,7 @@ class TIRFrame : public Frame {
/*! \brief Creates the TIR common prefix, which is by default `T` */
inline ExprDoc TIR(const IRDocsifier& d, const String& attr) {
+ d->ir_usage.insert("tir"); // record usage so HeaderWrapper prints the `tir` import header
return IdDoc(d->cfg->tir_prefix)->Attr(attr);
}
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
index cb20eb363d..e90fbc0fb3 100644
--- a/src/script/printer/utils.h
+++ b/src/script/printer/utils.h
@@ -69,6 +69,33 @@ inline std::string DType2Str(const runtime::DataType& dtype) {
return dtype.is_void() ? "void" : runtime::DLDataType2String(dtype);
}
+/*!
+ * \brief Prepend `# from tvm.script import ...` header comments to a printed doc.
+ * \param d The IRDocsifier that produced `doc`; its `ir_usage` selects the headers.
+ * \param doc The printed doc. It is downcast to StmtDoc when any header is emitted.
+ * \return `doc` unchanged if no IR dialect was used; otherwise a StmtBlockDoc holding
+ * the header comments, a blank line, and `doc`.
+ */
+inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) {
+ if (!d->ir_usage.empty()) {
+ Array<StmtDoc> stmts;
+ if (d->ir_usage.count("ir")) {
+ stmts.push_back(CommentDoc("from tvm.script import ir as " + d->cfg->ir_prefix));
+ }
+ if (d->ir_usage.count("tir")) {
+ stmts.push_back(CommentDoc("from tvm.script import tir as " + d->cfg->tir_prefix));
+ }
+ if (d->ir_usage.count("relax")) {
+ stmts.push_back(CommentDoc("from tvm.script import relax as " + d->cfg->relax_prefix));
+ }
+ // An empty CommentDoc prints no text, leaving a blank line after the headers.
+ stmts.push_back(CommentDoc(""));
+ stmts.push_back(Downcast<StmtDoc>(doc));
+ return StmtBlockDoc(stmts);
+ }
+ return doc;
+}
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py
index 16a0c31ac3..6353627c58 100644
--- a/tests/python/unittest/test_tvmscript_printer_doc.py
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -29,7 +29,9 @@ from tvm.script.printer.doc import (
AttrAccessDoc,
CallDoc,
ClassDoc,
+ CommentDoc,
DictDoc,
+ DocStringDoc,
ExprStmtDoc,
ForDoc,
FunctionDoc,
@@ -505,6 +507,32 @@ def test_class_doc(decorators, body):
assert list(doc.body) == body
+@pytest.mark.parametrize(
+ "comment",
+ [
+ "",
+ "test comment 1",
+ "test comment 1\ntest comment 2",
+ ],
+)
+def test_comment_doc(comment):
+ doc = CommentDoc(comment)
+ assert doc.comment == comment
+
+
+@pytest.mark.parametrize(
+ "comment",
+ [
+ "",
+ "test comment 1",
+ "test comment 1\ntest comment 2",
+ ],
+)
+def test_doc_string_doc(comment):
+ doc = DocStringDoc(comment)
+ assert doc.comment == comment
+
+
def test_stmt_doc_comment():
doc = ExprStmtDoc(IdDoc("x"))
assert doc.comment is None
diff --git a/tests/python/unittest/test_tvmscript_printer_ir.py b/tests/python/unittest/test_tvmscript_printer_ir.py
index c3da3d8c70..6b3ac19a5e 100644
--- a/tests/python/unittest/test_tvmscript_printer_ir.py
+++ b/tests/python/unittest/test_tvmscript_printer_ir.py
@@ -37,6 +37,9 @@ def test_ir_module():
_assert_print(
mod,
"""
+# from tvm.script import ir as I
+# from tvm.script import tir as T
+
@I.ir_module
class Module:
@T.prim_func
diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
index d87f9ec69e..75beb59d02 100644
--- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
+++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
@@ -23,7 +23,9 @@ from tvm.script.printer.doc import (
AssignDoc,
CallDoc,
ClassDoc,
+ CommentDoc,
DictDoc,
+ DocStringDoc,
ExprStmtDoc,
ForDoc,
FunctionDoc,
@@ -53,7 +55,7 @@ def format_script(s: str) -> str:
non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()]
if not non_empty_lines:
# no actual content
- return "\n"
+ return ""
line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines]
spaces_to_remove = min(line_indents)
@@ -887,6 +889,58 @@ def test_print_class_doc(decorators, body, expected):
assert to_python_script(doc) == format_script(expected)
+@pytest.mark.parametrize(
+ "comment, expected",
+ [
+ (
+ "",
+ "",
+ ),
+ (
+ "test comment 1",
+ "# test comment 1",
+ ),
+ (
+ "test comment 1\ntest comment 2",
+ """
+ # test comment 1
+ # test comment 2
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_comment_doc(comment, expected):
+ doc = CommentDoc(comment)
+ assert to_python_script(doc) == format_script(expected)
+
+
+@pytest.mark.parametrize(
+ "comment, expected",
+ [
+ (
+ "",
+ "",
+ ),
+ (
+ "test comment 1",
+ '"""test comment 1"""',
+ ),
+ (
+ "test comment 1\ntest comment 2",
+ '''
+ """test comment 1
+ test comment 2"""
+ ''',
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_doc_string_doc(comment, expected):
+ doc = DocStringDoc(comment)
+ assert to_python_script(doc) == format_script(expected)
+
+
@pytest.mark.parametrize(
"doc, comment, expected",
[
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py
index ec69c54396..49a33cd0f0 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -41,6 +41,8 @@ def test_prim_func():
_assert_print(
func,
expected="""
+# from tvm.script import tir as T
+
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.evaluate(0)""",
@@ -62,6 +64,8 @@ def test_prim_func_no_sugar_inlined_buffer():
_assert_print(
func,
expected="""
+# from tvm.script import tir as T
+
@T.prim_func
def main(a: T.handle, B: T.Buffer((256, 256), "float32")):
A = T.match_buffer(a, (128, 128))
@@ -86,6 +90,8 @@ def test_prim_func_no_sugar_shared_buffer_data():
_assert_print(
func,
expected="""
+# from tvm.script import tir as T
+
@T.prim_func
def main(a: T.handle, b: T.handle):
A = T.match_buffer(a, (128, 128))
@@ -698,8 +704,12 @@ def test_remap():
v3 = T.axis.spatial(128, i3 - 1)
v4, v5 = T.axis.remap("RS", [i4, i5])
- expected_output = """@T.prim_func
+ expected_output = """
+# from tvm.script import tir as T
+
+@T.prim_func
def main():
+ # with T.block("root"):
for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
with T.block("update"):
v0 = T.axis.spatial(128, i0 + 1)
@@ -731,8 +741,12 @@ def test_root_block():
with T.block():
T.evaluate(0)
- expected_output = """@T.prim_func
+ expected_output = """
+# from tvm.script import tir as T
+
+@T.prim_func
def main():
+ # with T.block("root"):
a = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block(""):