This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b4c1995a98 [Node] Allow alternative root names in ObjectPath::Root() (#14569)
b4c1995a98 is described below

commit b4c1995a98fc22a316a53c28b7eacb5240fc3f89
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 12 09:19:10 2023 -0500

    [Node] Allow alternative root names in ObjectPath::Root() (#14569)
    
    * [Node] Allow alternative root names in ObjectPath::Root()
    
    Previously, the `ObjectPath` utility tracked an object's location
    within a tree-like structure, but the base of every path was
    hard-coded to the string `<root>`.  That is sufficient for use cases
    such as `StructuralEqual`, which have a clear root node.  Other use
    cases, such as using `ObjectPath` to specify an object's location
    relative to a known variable, need that variable's name as the root
    rather than the hard-coded string `<root>`.
    
    This commit adds an optional parameter to `ObjectPath::Root()` that
    provides an alternative name for the root node, supporting these use
    cases.  When no name is given, the root is still printed as `<root>`.
    
    * Update the Python API and add a unit test
---
 include/tvm/node/object_path.h            |  6 ++++--
 python/tvm/runtime/object_path.py         |  6 ++++--
 src/node/object_path.cc                   | 20 ++++++++++++++++----
 tests/python/unittest/test_object_path.py | 10 ++++++++++
 4 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h
index 35f947a68f..97a62bfd2d 100644
--- a/include/tvm/node/object_path.h
+++ b/include/tvm/node/object_path.h
@@ -122,7 +122,7 @@ class ObjectPathNode : public Object {
 class ObjectPath : public ObjectRef {
  public:
   /*! \brief Create a path that represents the root object itself. */
-  static ObjectPath Root();
+  static ObjectPath Root(Optional<String> name = NullOpt);
 
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, 
ObjectPathNode);
 };
@@ -135,7 +135,9 @@ class ObjectPath : public ObjectRef {
 
 class RootPathNode final : public ObjectPathNode {
  public:
-  explicit RootPathNode();
+  Optional<String> name;
+
+  explicit RootPathNode(Optional<String> name = NullOpt);
 
   static constexpr const char* _type_key = "RootPath";
   TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode);
diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py
index ecca85d53d..ff223b7599 100644
--- a/python/tvm/runtime/object_path.py
+++ b/python/tvm/runtime/object_path.py
@@ -20,6 +20,8 @@ ObjectPath class that represents a path from a root object to one of its descend
 via attribute access, array indexing etc.
 """
 
+from typing import Optional
+
 import tvm._ffi
 from tvm.runtime import Object
 from . import _ffi_node_api
@@ -52,8 +54,8 @@ class ObjectPath(Object):
         )
 
     @staticmethod
-    def root() -> "ObjectPath":
-        return _ffi_node_api.ObjectPathRoot()
+    def root(root_name: Optional[str] = None) -> "ObjectPath":
+        return _ffi_node_api.ObjectPathRoot(root_name)
 
     def __eq__(self, other):
         return _ffi_node_api.ObjectPathEqual(self, other)
diff --git a/src/node/object_path.cc b/src/node/object_path.cc
index 9c49daa8c3..4d88873e79 100644
--- a/src/node/object_path.cc
+++ b/src/node/object_path.cc
@@ -197,7 +197,9 @@ const ObjectPathNode* ObjectPathNode::ParentNode() const {
 
 // ============== ObjectPath ==============
 
-/* static */ ObjectPath ObjectPath::Root() { return 
ObjectPath(make_object<RootPathNode>()); }
+/* static */ ObjectPath ObjectPath::Root(Optional<String> name) {
+  return ObjectPath(make_object<RootPathNode>(name));
+}
 
 TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root);
 
@@ -205,11 +207,21 @@ TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root);
 
 // ----- Root -----
 
-RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {}
+RootPathNode::RootPathNode(Optional<String> name) : ObjectPathNode(nullptr), 
name(name) {}
+
+bool RootPathNode::LastNodeEqual(const ObjectPathNode* other_path) const {
+  const auto* other = static_cast<const RootPathNode*>(other_path);
 
-bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return 
true; }
+  if (other->name.defined() != name.defined()) {
+    return false;
+  } else if (name && other->name) {
+    return name.value() == other->name.value();
+  } else {
+    return true;
+  }
+}
 
-std::string RootPathNode::LastNodeString() const { return "<root>"; }
+std::string RootPathNode::LastNodeString() const { return 
name.value_or("<root>"); }
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, 
vtable).set_dispatch<RootPathNode>(PrintObjectPathRepr);
 
diff --git a/tests/python/unittest/test_object_path.py b/tests/python/unittest/test_object_path.py
index f849c129df..3fea5141c7 100644
--- a/tests/python/unittest/test_object_path.py
+++ b/tests/python/unittest/test_object_path.py
@@ -30,6 +30,16 @@ def test_root_path():
     assert root.parent is None
 
 
+def test_named_root_path():
+    root = ObjectPath.root("base_name")
+    assert isinstance(root, object_path.RootPath)
+    assert str(root) == "base_name"
+    assert len(root) == 1
+    assert root != ObjectPath.root()
+    assert root == ObjectPath.root("base_name")
+    assert root.parent is None
+
+
 def test_path_attr():
     path = ObjectPath.root().attr("foo")
     assert isinstance(path, object_path.AttributeAccessPath)

Reply via email to