This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 697fdb2cb7 [TVMScript] Comments and docstrings printing (#13839)
697fdb2cb7 is described below

commit 697fdb2cb7dc7ad07ed826f908390b88106cc98f
Author: Yaxing Cai <[email protected]>
AuthorDate: Wed Jan 25 21:42:31 2023 -0800

    [TVMScript] Comments and docstrings printing (#13839)
    
    This PR introduces `CommentDoc` for printing comments and 
`DocStringDoc` for printing docstrings. It enables adding free comments and 
docstrings as `stmt`s when printing, e.g.
    ```python
    # comment 1
    # comment 2
    """
    docstring 1
    docstring 2
    """
    ```
    "Free" here means that the comment or docstring is not bound to any 
`stmt`, but acts as a single `stmt`, similar to `ExprStmtDoc` for `ExprDoc`.
    
    This PR also introduces an example use of `CommentDoc`, as a follow-up to 
#13819.
    In the old printer, we always printed `# with T.block("root"):` when an 
implicit root block was skipped during printing. For example,
    ```
    @T.prim_func
    def main():
      # with T.block("root"):
      a = T.alloc_buffer((128, 128))
      for i, j in T.grid(128, 128):
        with T.block(""):
          ...
    ```
    We bring this syntax reminder back in this PR.
    In addition, we introduce the fields `ir_usage` and `print_headers` into 
the printer configuration, to support printing headers for `IRModule` 
and `PrimFunc`. For example,
    
    ```python
    # from tvm.script import ir as I
    # from tvm.script import tir as T
    
    @I.ir_module
    class Module():
      @T.prim_func
      def func():
        ...
    ```
---
 include/tvm/script/printer/doc.h                   | 44 +++++++++++++++++
 include/tvm/script/printer/ir_docsifier.h          |  4 ++
 python/tvm/script/printer/doc.py                   | 20 ++++++++
 src/script/printer/doc.cc                          | 22 +++++++++
 src/script/printer/doc_printer/base_doc_printer.cc |  4 ++
 src/script/printer/doc_printer/base_doc_printer.h  | 10 ++++
 .../printer/doc_printer/python_doc_printer.cc      | 34 ++++++++++---
 src/script/printer/ir/ir.cc                        |  3 +-
 src/script/printer/ir/utils.h                      |  1 +
 src/script/printer/tir/function.cc                 |  4 +-
 src/script/printer/tir/utils.h                     |  1 +
 src/script/printer/utils.h                         | 20 ++++++++
 .../python/unittest/test_tvmscript_printer_doc.py  | 28 +++++++++++
 tests/python/unittest/test_tvmscript_printer_ir.py |  3 ++
 .../test_tvmscript_printer_python_doc_printer.py   | 56 +++++++++++++++++++++-
 .../python/unittest/test_tvmscript_printer_tir.py  | 18 ++++++-
 16 files changed, 261 insertions(+), 11 deletions(-)

diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 6504e2c284..6321caa4e0 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -1194,6 +1194,50 @@ class ClassDoc : public StmtDoc {
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode);
 };
 
+/*!
+ * \brief Doc that represents comment.
+ *
+ * \sa CommentDoc
+ */
+class CommentDocNode : public StmtDocNode {
+ public:
+  static constexpr const char* _type_key = "script.printer.CommentDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CommentDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of CommentDocNode.
+ *
+ * \sa CommentDocNode
+ */
+class CommentDoc : public StmtDoc {
+ public:
+  explicit CommentDoc(String comment);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CommentDoc, StmtDoc, 
CommentDocNode);
+};
+
+/*!
+ * \brief Doc that represents docstring.
+ *
+ * \sa DocStringDoc
+ */
+class DocStringDocNode : public StmtDocNode {
+ public:
+  static constexpr const char* _type_key = "script.printer.DocStringDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DocStringDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of DocStringDocNode.
+ *
+ * \sa DocStringDocNode
+ */
+class DocStringDoc : public StmtDoc {
+ public:
+  explicit DocStringDoc(String docs);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DocStringDoc, StmtDoc, 
DocStringDocNode);
+};
+
 }  // namespace printer
 }  // namespace script
 }  // namespace tvm
diff --git a/include/tvm/script/printer/ir_docsifier.h 
b/include/tvm/script/printer/ir_docsifier.h
index 67fa96ef80..c41827fe95 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -24,6 +24,7 @@
 #include <tvm/script/printer/doc.h>
 #include <tvm/script/printer/ir_docsifier_functor.h>
 
+#include <string>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
@@ -148,6 +149,8 @@ class IRDocsifierNode : public Object {
   std::unordered_set<String> defined_names;
   /*! \brief Common prefixes of variable usages */
   std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
+  /*! \brief The IR usages for headers printing */
+  std::unordered_set<std::string> ir_usage;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("frames", &frames);
@@ -156,6 +159,7 @@ class IRDocsifierNode : public Object {
     // `obj2info` is not visited
     // `defined_names` is not visited
     // `common_prefix` is not visited
+    // `ir_usage` is not visited
   }
 
   static constexpr const char* _type_key = "script.printer.IRDocsifier";
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index 5a4a4cd67a..9a6e7f1b8c 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -521,3 +521,23 @@ class ClassDoc(StmtDoc):
             decorators,
             body,
         )
+
+
+@register_object("script.printer.CommentDoc")
+class CommentDoc(StmtDoc):
+    """Doc that represents comment."""
+
+    def __init__(self, comment: str):
+        self.__init_handle_by_constructor__(
+            _ffi_api.CommentDoc, comment  # type: ignore # pylint: 
disable=no-member
+        )
+
+
+@register_object("script.printer.DocStringDoc")
+class DocStringDoc(StmtDoc):
+    """Doc that represents docstring."""
+
+    def __init__(self, docs: str):
+        self.__init_handle_by_constructor__(
+            _ffi_api.DocStringDoc, docs  # type: ignore # pylint: 
disable=no-member
+        )
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index 89f6b7c8b1..1db4e090dc 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -221,6 +221,18 @@ ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, 
Array<StmtDoc> body) {
   this->data_ = std::move(n);
 }
 
+CommentDoc::CommentDoc(String comment) {
+  ObjectPtr<CommentDocNode> n = make_object<CommentDocNode>();
+  n->comment = comment;
+  this->data_ = std::move(n);
+}
+
+DocStringDoc::DocStringDoc(String docs) {
+  ObjectPtr<DocStringDocNode> n = make_object<DocStringDocNode>();
+  n->comment = docs;
+  this->data_ = std::move(n);
+}
+
 TVM_REGISTER_NODE_TYPE(DocNode);
 TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
     .set_body_typed([](Doc doc, Array<ObjectPath> source_paths) {
@@ -365,6 +377,16 @@ TVM_REGISTER_GLOBAL("script.printer.ClassDoc")
       return ClassDoc(name, decorators, body);
     });
 
+TVM_REGISTER_NODE_TYPE(CommentDocNode);
+TVM_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String 
comment) {
+  return CommentDoc(comment);
+});
+
+TVM_REGISTER_NODE_TYPE(DocStringDocNode);
+TVM_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String 
docs) {
+  return DocStringDoc(docs);
+});
+
 }  // namespace printer
 }  // namespace script
 }  // namespace tvm
diff --git a/src/script/printer/doc_printer/base_doc_printer.cc 
b/src/script/printer/doc_printer/base_doc_printer.cc
index a3a5c06ede..8df599347f 100644
--- a/src/script/printer/doc_printer/base_doc_printer.cc
+++ b/src/script/printer/doc_printer/base_doc_printer.cc
@@ -316,6 +316,10 @@ void DocPrinter::PrintDoc(const Doc& doc) {
     PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
   } else if (const auto* doc_node = doc.as<ClassDocNode>()) {
     PrintTypedDoc(GetRef<ClassDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<CommentDocNode>()) {
+    PrintTypedDoc(GetRef<CommentDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<DocStringDocNode>()) {
+    PrintTypedDoc(GetRef<DocStringDoc>(doc_node));
   } else {
     LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
     throw;
diff --git a/src/script/printer/doc_printer/base_doc_printer.h 
b/src/script/printer/doc_printer/base_doc_printer.h
index 7851ce061b..f5cf40a233 100644
--- a/src/script/printer/doc_printer/base_doc_printer.h
+++ b/src/script/printer/doc_printer/base_doc_printer.h
@@ -204,6 +204,16 @@ class DocPrinter {
    */
   virtual void PrintTypedDoc(const ClassDoc& doc) = 0;
 
+  /*!
+   * \brief Virtual method to print a CommentDoc
+   */
+  virtual void PrintTypedDoc(const CommentDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print a DocStringDoc
+   */
+  virtual void PrintTypedDoc(const DocStringDoc& doc) = 0;
+
   /*!
    * \brief Increase the indent level of any content to be
    *        printed after this call
diff --git a/src/script/printer/doc_printer/python_doc_printer.cc 
b/src/script/printer/doc_printer/python_doc_printer.cc
index ce6b8e7f42..334f76f722 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -169,6 +169,8 @@ class PythonDocPrinter : public DocPrinter {
   void PrintTypedDoc(const ScopeDoc& doc) final;
   void PrintTypedDoc(const FunctionDoc& doc) final;
   void PrintTypedDoc(const ClassDoc& doc) final;
+  void PrintTypedDoc(const CommentDoc& doc) final;
+  void PrintTypedDoc(const DocStringDoc& doc) final;
 
  private:
   void NewLineWithoutIndent() { output_ << "\n"; }
@@ -253,11 +255,19 @@ class PythonDocPrinter : public DocPrinter {
     }
   }
 
-  void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
+  void MaybePrintCommenMultiLines(const StmtDoc& stmt, bool new_line = false) {
     if (stmt->comment.defined()) {
       std::vector<std::string> comment_lines = 
support::Split(stmt->comment.value(), '\n');
+      bool first_line = true;
       for (const std::string& line : comment_lines) {
-        output_ << "# " << line;
+        if (first_line) {
+          output_ << "# " << line;
+          first_line = false;
+        } else {
+          NewLine() << "# " << line;
+        }
+      }
+      if (new_line) {
         NewLine();
       }
     }
@@ -523,7 +533,7 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
 }
 
 void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
-  MaybePrintCommentWithNewLine(doc);
+  MaybePrintCommenMultiLines(doc, true);
   output_ << "if ";
   PrintDoc(doc->predicate);
   output_ << ":";
@@ -538,7 +548,7 @@ void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
 }
 
 void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
-  MaybePrintCommentWithNewLine(doc);
+  MaybePrintCommenMultiLines(doc, true);
   output_ << "while ";
   PrintDoc(doc->predicate);
   output_ << ":";
@@ -547,7 +557,7 @@ void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
 }
 
 void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
-  MaybePrintCommentWithNewLine(doc);
+  MaybePrintCommenMultiLines(doc, true);
   output_ << "for ";
   if (const auto* tuple = doc->lhs.as<TupleDocNode>()) {
     if (tuple->elements.size() == 1) {
@@ -567,7 +577,7 @@ void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
 }
 
 void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
-  MaybePrintCommentWithNewLine(doc);
+  MaybePrintCommenMultiLines(doc, true);
   output_ << "with ";
   PrintDoc(doc->rhs);
   if (doc->lhs != nullptr) {
@@ -642,6 +652,18 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
   NewLineWithoutIndent();
 }
 
+void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) {
+  if (doc->comment.defined()) {
+    MaybePrintCommenMultiLines(doc, false);
+  }
+}
+
+void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) {
+  if (doc->comment.defined() && !doc->comment.value().empty()) {
+    output_ << "\"\"\"" << doc->comment.value() << "\"\"\"";
+  }
+}
+
 String DocToPythonScript(Doc doc, const PrinterConfig& cfg) {
   if (cfg->num_context_lines < 0) {
     cfg->num_context_lines = std::numeric_limits<int32_t>::max();
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index 4a246e1692..7f7857dba6 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -119,7 +119,8 @@ std::string ReprPrintIRModule(const ObjectRef& mod, const 
PrinterConfig& cfg) {
       return s.value();
     }
   }
-  Doc doc = IRDocsifier(cfg)->AsDoc(mod, ObjectPath::Root());
+  IRDocsifier d(cfg);
+  Doc doc = HeaderWrapper(d, d->AsDoc(mod, ObjectPath::Root()));
   return DocToPythonScript(doc, cfg);
 }
 
diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h
index d20756e608..a05030516f 100644
--- a/src/script/printer/ir/utils.h
+++ b/src/script/printer/ir/utils.h
@@ -36,6 +36,7 @@ namespace printer {
 
 /*! \brief Creates the IR common prefix, which is by default `I` */
 inline ExprDoc IR(const IRDocsifier& d, const String& attr) {
+  d->ir_usage.insert("ir");
   return IdDoc(d->cfg->ir_prefix)->Attr(attr);
 }
 
diff --git a/src/script/printer/tir/function.cc 
b/src/script/printer/tir/function.cc
index fbcc2fca3b..65f3db5b4f 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -153,6 +153,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       if (implicit_root_block) {
         tir::Block root_block = implicit_root_block.value();
         ObjectPath root_block_p = p->Attr("body")->Attr("body");
+        (*frame)->stmts.push_back(CommentDoc("with T.block(\"root\"):"));
         // Handle root block `alloc_buffer`
         for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) {
           tir::Buffer buffer = root_block->alloc_buffers[i];
@@ -181,7 +182,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     });
 
 std::string ReprPrintPrimFunc(const ObjectRef& obj, const PrinterConfig& cfg) {
-  Doc doc = IRDocsifier(cfg)->AsDoc(obj, ObjectPath::Root());
+  IRDocsifier d(cfg);
+  Doc doc = HeaderWrapper(d, d->AsDoc(obj, ObjectPath::Root()));
   return DocToPythonScript(doc, cfg);
 }
 
diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h
index 88094ee816..0eead9a577 100644
--- a/src/script/printer/tir/utils.h
+++ b/src/script/printer/tir/utils.h
@@ -74,6 +74,7 @@ class TIRFrame : public Frame {
 
 /*! \brief Creates the TIR common prefix, which is by default `T` */
 inline ExprDoc TIR(const IRDocsifier& d, const String& attr) {
+  d->ir_usage.insert("tir");
   return IdDoc(d->cfg->tir_prefix)->Attr(attr);
 }
 
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
index cb20eb363d..e90fbc0fb3 100644
--- a/src/script/printer/utils.h
+++ b/src/script/printer/utils.h
@@ -69,6 +69,26 @@ inline std::string DType2Str(const runtime::DataType& dtype) 
{
   return dtype.is_void() ? "void" : runtime::DLDataType2String(dtype);
 }
 
+/*! \brief Add headers as comments to doc if needed */
+inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) {
+  if (d->ir_usage.size()) {
+    Array<StmtDoc> stmts;
+    if (d->ir_usage.count("ir")) {
+      stmts.push_back(CommentDoc("from tvm.script import ir as " + 
d->cfg->ir_prefix));
+    }
+    if (d->ir_usage.count("tir")) {
+      stmts.push_back(CommentDoc("from tvm.script import tir as " + 
d->cfg->tir_prefix));
+    }
+    if (d->ir_usage.count("relax")) {
+      stmts.push_back(CommentDoc("from tvm.script import relax as " + 
d->cfg->relax_prefix));
+    }
+    stmts.push_back(CommentDoc(""));
+    stmts.push_back(Downcast<StmtDoc>(doc));
+    return StmtBlockDoc(stmts);
+  }
+  return doc;
+}
+
 }  // namespace printer
 }  // namespace script
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py 
b/tests/python/unittest/test_tvmscript_printer_doc.py
index 16a0c31ac3..6353627c58 100644
--- a/tests/python/unittest/test_tvmscript_printer_doc.py
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -29,7 +29,9 @@ from tvm.script.printer.doc import (
     AttrAccessDoc,
     CallDoc,
     ClassDoc,
+    CommentDoc,
     DictDoc,
+    DocStringDoc,
     ExprStmtDoc,
     ForDoc,
     FunctionDoc,
@@ -505,6 +507,32 @@ def test_class_doc(decorators, body):
     assert list(doc.body) == body
 
 
[email protected](
+    "comment",
+    [
+        "",
+        "test comment 1",
+        "test comment 1\ntest comment 1",
+    ],
+)
+def test_comment_doc(comment):
+    doc = CommentDoc(comment)
+    assert doc.comment == comment
+
+
[email protected](
+    "comment",
+    [
+        "",
+        "test comment 1",
+        "test comment 1\ntest comment 1",
+    ],
+)
+def test_doc_string_doc(comment):
+    doc = DocStringDoc(comment)
+    assert doc.comment == comment
+
+
 def test_stmt_doc_comment():
     doc = ExprStmtDoc(IdDoc("x"))
     assert doc.comment is None
diff --git a/tests/python/unittest/test_tvmscript_printer_ir.py 
b/tests/python/unittest/test_tvmscript_printer_ir.py
index c3da3d8c70..6b3ac19a5e 100644
--- a/tests/python/unittest/test_tvmscript_printer_ir.py
+++ b/tests/python/unittest/test_tvmscript_printer_ir.py
@@ -37,6 +37,9 @@ def test_ir_module():
     _assert_print(
         mod,
         """
+# from tvm.script import ir as I
+# from tvm.script import tir as T
+
 @I.ir_module
 class Module:
     @T.prim_func
diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py 
b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
index d87f9ec69e..75beb59d02 100644
--- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
+++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
@@ -23,7 +23,9 @@ from tvm.script.printer.doc import (
     AssignDoc,
     CallDoc,
     ClassDoc,
+    CommentDoc,
     DictDoc,
+    DocStringDoc,
     ExprStmtDoc,
     ForDoc,
     FunctionDoc,
@@ -53,7 +55,7 @@ def format_script(s: str) -> str:
     non_empty_lines = [line for line in s.splitlines() if line and not 
line.isspace()]
     if not non_empty_lines:
         # no actual content
-        return "\n"
+        return ""
 
     line_indents = [len(line) - len(line.lstrip(" ")) for line in 
non_empty_lines]
     spaces_to_remove = min(line_indents)
@@ -887,6 +889,58 @@ def test_print_class_doc(decorators, body, expected):
     assert to_python_script(doc) == format_script(expected)
 
 
[email protected](
+    "comment, expected",
+    [
+        (
+            "",
+            "",
+        ),
+        (
+            "test comment 1",
+            "# test comment 1",
+        ),
+        (
+            "test comment 1\ntest comment 2",
+            """
+            # test comment 1
+            # test comment 2
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_comment_doc(comment, expected):
+    doc = CommentDoc(comment)
+    assert to_python_script(doc) == format_script(expected)
+
+
[email protected](
+    "comment, expected",
+    [
+        (
+            "",
+            "",
+        ),
+        (
+            "test comment 1",
+            '"""test comment 1"""',
+        ),
+        (
+            "test comment 1\ntest comment 2",
+            '''
+            """test comment 1
+            test comment 2"""
+            ''',
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_doc_string_doc(comment, expected):
+    doc = DocStringDoc(comment)
+    assert to_python_script(doc) == format_script(expected)
+
+
 @pytest.mark.parametrize(
     "doc, comment, expected",
     [
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py 
b/tests/python/unittest/test_tvmscript_printer_tir.py
index ec69c54396..49a33cd0f0 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -41,6 +41,8 @@ def test_prim_func():
     _assert_print(
         func,
         expected="""
+# from tvm.script import tir as T
+
 @T.prim_func
 def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), 
"float32")):
     T.evaluate(0)""",
@@ -62,6 +64,8 @@ def test_prim_func_no_sugar_inlined_buffer():
     _assert_print(
         func,
         expected="""
+# from tvm.script import tir as T
+
 @T.prim_func
 def main(a: T.handle, B: T.Buffer((256, 256), "float32")):
     A = T.match_buffer(a, (128, 128))
@@ -86,6 +90,8 @@ def test_prim_func_no_sugar_shared_buffer_data():
     _assert_print(
         func,
         expected="""
+# from tvm.script import tir as T
+
 @T.prim_func
 def main(a: T.handle, b: T.handle):
     A = T.match_buffer(a, (128, 128))
@@ -698,8 +704,12 @@ def test_remap():
                 v3 = T.axis.spatial(128, i3 - 1)
                 v4, v5 = T.axis.remap("RS", [i4, i5])
 
-    expected_output = """@T.prim_func
+    expected_output = """
+# from tvm.script import tir as T
+
[email protected]_func
 def main():
+    # with T.block("root"):
     for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
         with T.block("update"):
             v0 = T.axis.spatial(128, i0 + 1)
@@ -731,8 +741,12 @@ def test_root_block():
                 with T.block():
                     T.evaluate(0)
 
-    expected_output = """@T.prim_func
+    expected_output = """
+# from tvm.script import tir as T
+
[email protected]_func
 def main():
+    # with T.block("root"):
     a = T.alloc_buffer((128, 128))
     for i, j in T.grid(128, 128):
         with T.block(""):

Reply via email to