This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0973248858 [TVMScript] Add source_paths to Doc (#12324)
0973248858 is described below

commit 09732488588f331bb1a46e1e2138567035797126
Author: Lite Ye <[email protected]>
AuthorDate: Fri Aug 5 16:33:38 2022 -0400

    [TVMScript] Add source_paths to Doc (#12324)
    
    This PR:
    
    - Add the source_paths attribute to the Doc base class.
    - Add the corresponding Python binding for it.
    
    This PR is depended on by multiple tasks, including the diagnostic output in DocPrinter, VarTable and IRDocsifier.
    
    Tracking issue: https://github.com/apache/tvm/issues/11912
    
    Co-authored-by: Greg Bonik <[email protected]>
---
 include/tvm/script/printer/doc.h                   | 11 ++++++++-
 python/tvm/script/printer/doc.py                   | 27 ++++++++++++++++++++--
 src/script/printer/doc.cc                          |  4 ++++
 .../python/unittest/test_tvmscript_printer_doc.py  | 19 +++++++++++++++
 4 files changed, 58 insertions(+), 3 deletions(-)

diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 408c703d54..55faed33fb 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -40,7 +40,16 @@ namespace printer {
  */
 class DocNode : public Object {
  public:
-  void VisitAttrs(AttrVisitor* v) {}
+  /*!
+   * \brief The list of object paths of the source IR node.
+   *
+   * This is used to trace back to the IR node position where
+   * this Doc is generated, in order to position the diagnostic
+   * message.
+   */
+  mutable Array<ObjectPath> source_paths;
+
+  void VisitAttrs(AttrVisitor* v) { v->Visit("source_paths", &source_paths); }
 
   static constexpr const char* _type_key = "script.printer.Doc";
   TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object);
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index 0a5fde8975..5ac0723976 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -20,7 +20,7 @@ from enum import IntEnum, unique
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from tvm._ffi import register_object
-from tvm.runtime import Object
+from tvm.runtime import Object, ObjectPath
 from tvm.tir import FloatImm, IntImm
 
 from . import _ffi_api
@@ -29,8 +29,23 @@ from . import _ffi_api
 class Doc(Object):
     """Base class of all Docs"""
 
+    @property
+    def source_paths(self) -> Sequence[ObjectPath]:
+        """
+        The list of object paths of the source IR node.
+
+        This is used to trace back to the IR node position where
+        this Doc is generated, in order to position the diagnostic
+        message.
+        """
+        return self.__getattr__("source_paths")  # pylint: disable=unnecessary-dunder-call
+
+    @source_paths.setter
+    def source_paths(self, value):
+        return _ffi_api.DocSetSourcePaths(self, value)  # type: ignore # pylint: disable=no-member
 
-class ExprDoc(Object):
+
+class ExprDoc(Doc):
     """Base class of all expression Docs"""
 
     def attr(self, name: str) -> "AttrAccessDoc":
@@ -104,6 +119,14 @@ class StmtDoc(Doc):
 
     @property
     def comment(self) -> Optional[str]:
+        """
+        The comment of this doc.
+
+        The actual position of the comment depends on the type of Doc
+        and also the DocPrinter implementation. It could be on the same
+        line as the statement, or the line above, or inside the statement
+        if it spans over multiple lines.
+        """
         # It has to call the dunder method to avoid infinite recursion
         return self.__getattr__("comment")  # pylint: disable=unnecessary-dunder-call
 
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index 2334d1fad5..b94d4c55bf 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -217,6 +217,10 @@ ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) {
 }
 
 TVM_REGISTER_NODE_TYPE(DocNode);
+TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
+    .set_body_typed([](Doc doc, Array<ObjectPath> source_paths) {
+      doc->source_paths = source_paths;
+    });
 
 TVM_REGISTER_NODE_TYPE(ExprDocNode);
 
TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method<ExprDoc>(&ExprDocNode::Attr);
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py
index f27bc71b66..16a0c31ac3 100644
--- a/tests/python/unittest/test_tvmscript_printer_doc.py
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -22,6 +22,7 @@ Doc objects, then access and modify their attributes correctly.
 import pytest
 
 import tvm
+from tvm.runtime import ObjectPath
 from tvm.script.printer.doc import (
     AssertDoc,
     AssignDoc,
@@ -510,8 +511,26 @@ def test_stmt_doc_comment():
 
     comment = "test comment"
     doc.comment = comment
+    # Make sure the previous statement doesn't set attribute
+    # as if it's an ordinary Python object.
+    assert "comment" not in doc.__dict__
     assert doc.comment == comment
 
 
+def test_doc_source_paths():
+    doc = IdDoc("x")
+    assert len(doc.source_paths) == 0
+
+    source_paths = [ObjectPath.root(), ObjectPath.root().attr("x")]
+
+    doc.source_paths = source_paths
+    # This should trigger __getattr__ and return a tvm.ir.container.Array
+    assert not isinstance(doc.source_paths, list)
+    assert list(doc.source_paths) == source_paths
+
+    doc.source_paths = []
+    assert len(doc.source_paths) == 0
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to