This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b4c1995a98 [Node] Allow alternative root names in ObjectPath::Root() (#14569)
b4c1995a98 is described below
commit b4c1995a98fc22a316a53c28b7eacb5240fc3f89
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 12 09:19:10 2023 -0500
[Node] Allow alternative root names in ObjectPath::Root() (#14569)
* [Node] Allow alternative root names in ObjectPath::Root()
The `ObjectPath` utility tracks an object's location within a tree-like
structure, but the base of the path was hard-coded to be the string
`<root>`. This works for use cases such as `StructuralEqual`, which have
a clear root node. Other use cases, such as using `ObjectPath` to specify
an object's location relative to a known variable, need that variable's
name as the root rather than the hard-coded string `<root>`.
This commit adds an optional parameter to `ObjectPath::Root()` that
provides an alternative name for the root node, enabling these use cases.
* Update the Python API and add a unit test
---
include/tvm/node/object_path.h | 6 ++++--
python/tvm/runtime/object_path.py | 6 ++++--
src/node/object_path.cc | 20 ++++++++++++++++----
tests/python/unittest/test_object_path.py | 10 ++++++++++
4 files changed, 34 insertions(+), 8 deletions(-)
diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h
index 35f947a68f..97a62bfd2d 100644
--- a/include/tvm/node/object_path.h
+++ b/include/tvm/node/object_path.h
@@ -122,7 +122,7 @@ class ObjectPathNode : public Object {
class ObjectPath : public ObjectRef {
public:
/*! \brief Create a path that represents the root object itself. */
- static ObjectPath Root();
+ static ObjectPath Root(Optional<String> name = NullOpt);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef,
ObjectPathNode);
};
@@ -135,7 +135,9 @@ class ObjectPath : public ObjectRef {
class RootPathNode final : public ObjectPathNode {
public:
- explicit RootPathNode();
+ Optional<String> name;
+
+ explicit RootPathNode(Optional<String> name = NullOpt);
static constexpr const char* _type_key = "RootPath";
TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode);
diff --git a/python/tvm/runtime/object_path.py
b/python/tvm/runtime/object_path.py
index ecca85d53d..ff223b7599 100644
--- a/python/tvm/runtime/object_path.py
+++ b/python/tvm/runtime/object_path.py
@@ -20,6 +20,8 @@ ObjectPath class that represents a path from a root object to
one of its descend
via attribute access, array indexing etc.
"""
+from typing import Optional
+
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_node_api
@@ -52,8 +54,8 @@ class ObjectPath(Object):
)
@staticmethod
- def root() -> "ObjectPath":
- return _ffi_node_api.ObjectPathRoot()
+ def root(root_name: Optional[str] = None) -> "ObjectPath":
+ return _ffi_node_api.ObjectPathRoot(root_name)
def __eq__(self, other):
return _ffi_node_api.ObjectPathEqual(self, other)
diff --git a/src/node/object_path.cc b/src/node/object_path.cc
index 9c49daa8c3..4d88873e79 100644
--- a/src/node/object_path.cc
+++ b/src/node/object_path.cc
@@ -197,7 +197,9 @@ const ObjectPathNode* ObjectPathNode::ParentNode() const {
// ============== ObjectPath ==============
-/* static */ ObjectPath ObjectPath::Root() { return
ObjectPath(make_object<RootPathNode>()); }
+/* static */ ObjectPath ObjectPath::Root(Optional<String> name) {
+ return ObjectPath(make_object<RootPathNode>(name));
+}
TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root);
@@ -205,11 +207,21 @@
TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root);
// ----- Root -----
-RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {}
+RootPathNode::RootPathNode(Optional<String> name) : ObjectPathNode(nullptr),
name(name) {}
+
+bool RootPathNode::LastNodeEqual(const ObjectPathNode* other_path) const {
+ const auto* other = static_cast<const RootPathNode*>(other_path);
-bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return
true; }
+ if (other->name.defined() != name.defined()) {
+ return false;
+ } else if (name && other->name) {
+ return name.value() == other->name.value();
+ } else {
+ return true;
+ }
+}
-std::string RootPathNode::LastNodeString() const { return "<root>"; }
+std::string RootPathNode::LastNodeString() const { return
name.value_or("<root>"); }
TVM_STATIC_IR_FUNCTOR(ReprPrinter,
vtable).set_dispatch<RootPathNode>(PrintObjectPathRepr);
diff --git a/tests/python/unittest/test_object_path.py
b/tests/python/unittest/test_object_path.py
index f849c129df..3fea5141c7 100644
--- a/tests/python/unittest/test_object_path.py
+++ b/tests/python/unittest/test_object_path.py
@@ -30,6 +30,16 @@ def test_root_path():
assert root.parent is None
+def test_named_root_path():
+ root = ObjectPath.root("base_name")
+ assert isinstance(root, object_path.RootPath)
+ assert str(root) == "base_name"
+ assert len(root) == 1
+ assert root != ObjectPath.root()
+ assert root == ObjectPath.root("base_name")
+ assert root.parent is None
+
+
def test_path_attr():
path = ObjectPath.root().attr("foo")
assert isinstance(path, object_path.AttributeAccessPath)