This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2329b1a9a9 [Fix] Update mutator name rule (#16046)
2329b1a9a9 is described below
commit 2329b1a9a9d4ffca831c8be19d84db9798f510ba
Author: Lesheng Jin <[email protected]>
AuthorDate: Fri Nov 3 11:05:04 2023 -0700
[Fix] Update mutator name rule (#16046)
* [Fix] Update mutator name rule
* Update visitor.py
---------
Co-authored-by: Junru Shao <[email protected]>
---
python/tvm/relax/frontend/nn/visitor.py | 18 ++++++++++++++----
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py
index 2e89a6672c..82f3010066 100644
--- a/python/tvm/relax/frontend/nn/visitor.py
+++ b/python/tvm/relax/frontend/nn/visitor.py
@@ -16,6 +16,7 @@
# under the License.
"""The visitor and mutator infra for nn.Module."""
from typing import Any
+
from . import core as nn
@@ -112,6 +113,15 @@ class Mutator:
ret_node: Any
The new node to replace current node.
"""
+
+ def _get_child_name(parent: str, child: str) -> str:
+ """Get the name of the child node/key given the parent's name."""
+ if parent == "":
+ # in the top level of the module
+ return child
+ else:
+ return f"{parent}.{child}"
+
if isinstance(node, nn.ModuleList):
for i in range(len(node)):
if isinstance(node[i], nn.ModuleList):
@@ -125,11 +135,11 @@ class Mutator:
else:
for key, value in node.__dict__.items():
if isinstance(value, nn.ModuleList):
- setattr(node, key, self.visit_modulelist(f"{name}.{key}", value))
+ setattr(node, key, self.visit_modulelist(_get_child_name(name, key), value))
elif isinstance(value, nn.Module):
- setattr(node, key, self.visit_module(f"{name}.{key}", value))
+ setattr(node, key, self.visit_module(_get_child_name(name, key), value))
elif isinstance(value, nn.Effect):
- setattr(node, key, self.visit_effect(f"{name}.{key}", value))
+ setattr(node, key, self.visit_effect(_get_child_name(name, key), value))
elif isinstance(value, nn.Parameter):
- setattr(node, key, self.visit_param(f"{name}.{key}", value))
+ setattr(node, key, self.visit_param(_get_child_name(name, key), value))
return node