This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2329b1a9a9 [Fix] Update mutator name rule (#16046)
2329b1a9a9 is described below

commit 2329b1a9a9d4ffca831c8be19d84db9798f510ba
Author: Lesheng Jin <[email protected]>
AuthorDate: Fri Nov 3 11:05:04 2023 -0700

    [Fix] Update mutator name rule (#16046)
    
    * [Fix] Update mutator name rule
    
    * Update visitor.py
    
    ---------
    
    Co-authored-by: Junru Shao <[email protected]>
---
 python/tvm/relax/frontend/nn/visitor.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py
index 2e89a6672c..82f3010066 100644
--- a/python/tvm/relax/frontend/nn/visitor.py
+++ b/python/tvm/relax/frontend/nn/visitor.py
@@ -16,6 +16,7 @@
 # under the License.
 """The visitor and mutator infra for nn.Module."""
 from typing import Any
+
 from . import core as nn
 
 
@@ -112,6 +113,15 @@ class Mutator:
         ret_node: Any
             The new node to replace current node.
         """
+
+        def _get_child_name(parent: str, child: str) -> str:
+            """Get the name of the child node/key given the parent's name."""
+            if parent == "":
+                # in the top level of the module
+                return child
+            else:
+                return f"{parent}.{child}"
+
         if isinstance(node, nn.ModuleList):
             for i in range(len(node)):
                 if isinstance(node[i], nn.ModuleList):
@@ -125,11 +135,11 @@ class Mutator:
         else:
             for key, value in node.__dict__.items():
                 if isinstance(value, nn.ModuleList):
-                    setattr(node, key, self.visit_modulelist(f"{name}.{key}", value))
+                    setattr(node, key, self.visit_modulelist(_get_child_name(name, key), value))
                 elif isinstance(value, nn.Module):
-                    setattr(node, key, self.visit_module(f"{name}.{key}", value))
+                    setattr(node, key, self.visit_module(_get_child_name(name, key), value))
                 elif isinstance(value, nn.Effect):
-                    setattr(node, key, self.visit_effect(f"{name}.{key}", value))
+                    setattr(node, key, self.visit_effect(_get_child_name(name, key), value))
                 elif isinstance(value, nn.Parameter):
-                    setattr(node, key, self.visit_param(f"{name}.{key}", value))
+                    setattr(node, key, self.visit_param(_get_child_name(name, key), value))
         return node

Reply via email to