This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0973248858 [TVMScript] Add source_paths to Doc (#12324)
0973248858 is described below
commit 09732488588f331bb1a46e1e2138567035797126
Author: Lite Ye <[email protected]>
AuthorDate: Fri Aug 5 16:33:38 2022 -0400
[TVMScript] Add source_paths to Doc (#12324)
This PR:
- Add the source_paths attribute to the Doc base class.
- Add the corresponding Python binding for it.
Multiple tasks depend on this PR, including the diagnostic output in
DocPrinter, VarTable and IRDocsifier.
Tracking issue: https://github.com/apache/tvm/issues/11912
Co-authored-by: Greg Bonik <[email protected]>
---
include/tvm/script/printer/doc.h | 11 ++++++++-
python/tvm/script/printer/doc.py | 27 ++++++++++++++++++++--
src/script/printer/doc.cc | 4 ++++
.../python/unittest/test_tvmscript_printer_doc.py | 19 +++++++++++++++
4 files changed, 58 insertions(+), 3 deletions(-)
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 408c703d54..55faed33fb 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -40,7 +40,16 @@ namespace printer {
*/
class DocNode : public Object {
public:
- void VisitAttrs(AttrVisitor* v) {}
+ /*!
+ * \brief The list of object paths of the source IR node.
+ *
+ * This is used to trace back to the IR node position where
+ * this Doc is generated, in order to position the diagnostic
+ * message.
+ */
+ mutable Array<ObjectPath> source_paths;
+
+ void VisitAttrs(AttrVisitor* v) { v->Visit("source_paths", &source_paths); }
static constexpr const char* _type_key = "script.printer.Doc";
TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object);
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index 0a5fde8975..5ac0723976 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -20,7 +20,7 @@ from enum import IntEnum, unique
from typing import Dict, List, Optional, Sequence, Tuple, Union
from tvm._ffi import register_object
-from tvm.runtime import Object
+from tvm.runtime import Object, ObjectPath
from tvm.tir import FloatImm, IntImm
from . import _ffi_api
@@ -29,8 +29,23 @@ from . import _ffi_api
class Doc(Object):
"""Base class of all Docs"""
+ @property
+ def source_paths(self) -> Sequence[ObjectPath]:
+ """
+ The list of object paths of the source IR node.
+
+ This is used to trace back to the IR node position where
+ this Doc is generated, in order to position the diagnostic
+ message.
+ """
+ return self.__getattr__("source_paths") # pylint: disable=unnecessary-dunder-call
+
+ @source_paths.setter
+ def source_paths(self, value):
+ return _ffi_api.DocSetSourcePaths(self, value) # type: ignore # pylint: disable=no-member
-class ExprDoc(Object):
+
+class ExprDoc(Doc):
"""Base class of all expression Docs"""
def attr(self, name: str) -> "AttrAccessDoc":
@@ -104,6 +119,14 @@ class StmtDoc(Doc):
@property
def comment(self) -> Optional[str]:
+ """
+ The comment of this doc.
+
+ The actual position of the comment depends on the type of Doc
+ and also the DocPrinter implementation. It could be on the same
+ line as the statement, or the line above, or inside the statement
+ if it spans over multiple lines.
+ """
# It has to call the dunder method to avoid infinite recursion
return self.__getattr__("comment") # pylint: disable=unnecessary-dunder-call
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index 2334d1fad5..b94d4c55bf 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -217,6 +217,10 @@ ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators,
Array<StmtDoc> body) {
}
TVM_REGISTER_NODE_TYPE(DocNode);
+TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
+ .set_body_typed([](Doc doc, Array<ObjectPath> source_paths) {
+ doc->source_paths = source_paths;
+ });
TVM_REGISTER_NODE_TYPE(ExprDocNode);
TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method<ExprDoc>(&ExprDocNode::Attr);
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py
index f27bc71b66..16a0c31ac3 100644
--- a/tests/python/unittest/test_tvmscript_printer_doc.py
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -22,6 +22,7 @@ Doc objects, then access and modify their attributes
correctly.
import pytest
import tvm
+from tvm.runtime import ObjectPath
from tvm.script.printer.doc import (
AssertDoc,
AssignDoc,
@@ -510,8 +511,26 @@ def test_stmt_doc_comment():
comment = "test comment"
doc.comment = comment
+ # Make sure the previous statement doesn't set attribute
+ # as if it's an ordinary Python object.
+ assert "comment" not in doc.__dict__
assert doc.comment == comment
+def test_doc_source_paths():
+ doc = IdDoc("x")
+ assert len(doc.source_paths) == 0
+
+ source_paths = [ObjectPath.root(), ObjectPath.root().attr("x")]
+
+ doc.source_paths = source_paths
+ # This should trigger __getattr__ and get a tvm.ir.container.Array
+ assert not isinstance(doc.source_paths, list)
+ assert list(doc.source_paths) == source_paths
+
+ doc.source_paths = []
+ assert len(doc.source_paths) == 0
+
+
if __name__ == "__main__":
tvm.testing.main()