This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2b4a1e2fef [Relax][Frontend] Introduce ModuleDict (#18551)
2b4a1e2fef is described below

commit 2b4a1e2fefb226127b950528689a8b7947ad43bd
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Dec 7 13:24:12 2025 +0900

    [Relax][Frontend] Introduce ModuleDict (#18551)
    
    Introduce `nn.ModuleDict`, a container that holds named submodules, just
    like [ModuleDict in PyTorch](https://docs.pytorch.org/docs/stable/generated/torch.nn.ModuleDict.html).
---
 python/tvm/relax/frontend/nn/__init__.py       |  2 +-
 python/tvm/relax/frontend/nn/core.py           | 61 +++++++++++++++++++++++++
 python/tvm/relax/frontend/nn/visitor.py        | 40 ++++++++++++++--
 tests/python/relax/test_frontend_nn_modules.py | 17 +++++++
 tests/python/relax/test_frontend_nn_mutator.py | 63 +++++++++++++++++++++++++-
 5 files changed, 178 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index f490af7062..d903634883 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -17,7 +17,7 @@
 """A PyTorch-like API to build IRModules."""
 # pylint: disable=redefined-builtin
 from . import op, spec
-from .core import Effect, Module, ModuleList, Object, Parameter, Tensor
+from .core import Effect, Module, ModuleDict, ModuleList, Object, Parameter, 
Tensor
 from .exporter import add_extern
 from .extern import ExternModule, ObjectModule, SourceModule
 from .modules import (
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 8529dda006..b15ba685b7 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -540,6 +540,56 @@ class Module(SubroutineMixin):
         raise ValueError(f"Unknown out_format: {out_format}")
 
 
+class ModuleDict(Module):  # keys become path segments of parameter names
+    """Holds submodules in an insertion-ordered dict keyed by name."""
+
+    def __init__(self, modules: Optional[OrderedDict[str, Module]] = None):
+        if modules is None:
+            self.modules = OrderedDict()
+        else:
+            self.modules = OrderedDict(modules)  # shallow copy of the given mapping
+
+    def __iter__(self):
+        return iter(self.modules.values())  # NOTE: yields modules, not keys (unlike dict)
+
+    def __getitem__(self, key: str) -> Module:
+        return self.modules[key]  # KeyError if key is absent
+
+    def __setitem__(self, key: str, module: Module) -> None:
+        self.modules[key] = module  # replaces an existing entry; no type check
+
+    def __len__(self) -> int:
+        return len(self.modules)
+
+    def keys(self) -> Iterator[str]:
+        return self.modules.keys()  # a live view, not a one-shot iterator
+
+    def values(self) -> Iterator[Module]:
+        return self.modules.values()  # live view
+
+    def items(self) -> Iterator[Tuple[str, Module]]:
+        return self.modules.items()  # live view of (key, module) pairs
+
+    def get(self, key: str, default: Optional[Module] = None) -> 
Optional[Module]:
+        return self.modules.get(key, default)  # `default` (None) if key is absent
+
+    def update(self, modules: Dict[str, Module]) -> None:
+        self.modules.update(modules)  # existing keys are overwritten
+
+    def clear(self) -> None:
+        self.modules.clear()
+
+    def pop(self, key: str) -> Module:
+        return self.modules.pop(key)  # KeyError if absent; no default argument
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.modules  # tests keys, whereas __iter__ yields values
+
+    def to(self, dtype: Optional[str] = None) -> None:  # pylint: 
disable=invalid-name
+        for module in self.modules.values():  # forward dtype to each submodule
+            module.to(dtype=dtype)
+
+
 class ModuleList(Module):
     """Holds submodules in a list."""
 
@@ -611,6 +661,10 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any]
         for i, subitem in enumerate(root):
             yield from _attribute_finder(subitem, prefix + f"{i}.", 
condition_yield)
         return
+    elif isinstance(root, ModuleDict):
+        for name, subitem in root.items():
+            yield from _attribute_finder(subitem, prefix + f"{name}.", 
condition_yield)
+        return
     for name, item in root.__dict__.items():
         if condition_yield(item):
             yield prefix + name, item
@@ -620,6 +674,13 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any]
                 prefix + name + ".",
                 condition_yield,
             )
+        elif isinstance(item, ModuleDict):
+            for sub_name, sub_item in item.items():
+                yield from _attribute_finder(
+                    sub_item,
+                    prefix + name + f".{sub_name}.",
+                    condition_yield,
+                )
         elif isinstance(item, Module):
             yield from _attribute_finder(
                 item,
diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py
index 82f3010066..d2467a2bf8 100644
--- a/python/tvm/relax/frontend/nn/visitor.py
+++ b/python/tvm/relax/frontend/nn/visitor.py
@@ -79,6 +79,24 @@ class Mutator:
         """
         return self.visit(name, node)
 
+    def visit_moduledict(self, name: str, node: nn.ModuleDict) -> Any:
+        """The base visiting method for mutation of nn.ModuleDict nodes.
+
+        Parameters
+        ----------
+        name : str
+            The dotted name of the current node (parent name plus attribute or key).
+
+        node : nn.ModuleDict
+            The current node of nn.ModuleDict to mutate.
+
+        Returns
+        ------
+        ret_node: Any
+            The new node that replaces ``node`` in its parent.
+        """
+        return self.visit(name, node)  # default: visit every entry, reassign results
+
     def visit_modulelist(self, name: str, node: nn.ModuleList) -> Any:
         """The base visiting method for mutation of nn.ModuleList nodes.
 
@@ -88,7 +106,7 @@ class Mutator:
             The name of the current node in parent's attribute.
 
         node : nn.ModuleList
-            The current node of nn.MoModuleListdule to mutate.
+            The current node of nn.ModuleList to mutate.
 
         Returns
         ------
@@ -124,7 +142,9 @@ class Mutator:
 
         if isinstance(node, nn.ModuleList):
             for i in range(len(node)):
-                if isinstance(node[i], nn.ModuleList):
+                if isinstance(node[i], nn.ModuleDict):
+                    node[i] = self.visit_moduledict(f"{name}.{i}", node[i])
+                elif isinstance(node[i], nn.ModuleList):
                     node[i] = self.visit_modulelist(f"{name}.{i}", node[i])
                 elif isinstance(node[i], nn.Module):
                     node[i] = self.visit_module(f"{name}.{i}", node[i])
@@ -132,9 +152,23 @@ class Mutator:
                     node[i] = self.visit_effect(f"{name}.{i}", node[i])
                 elif isinstance(node[i], nn.Parameter):
                     node[i] = self.visit_param(f"{name}.{i}", node[i])
+        elif isinstance(node, nn.ModuleDict):
+            for k, v in node.items():
+                if isinstance(v, nn.ModuleDict):
+                    node[k] = self.visit_moduledict(_get_child_name(name, k), 
v)
+                elif isinstance(v, nn.ModuleList):
+                    node[k] = self.visit_modulelist(_get_child_name(name, k), 
v)
+                elif isinstance(v, nn.Module):
+                    node[k] = self.visit_module(_get_child_name(name, k), v)
+                elif isinstance(v, nn.Effect):
+                    node[k] = self.visit_effect(_get_child_name(name, k), v)
+                elif isinstance(v, nn.Parameter):
+                    node[k] = self.visit_param(_get_child_name(name, k), v)
         else:
             for key, value in node.__dict__.items():
-                if isinstance(value, nn.ModuleList):
+                if isinstance(value, nn.ModuleDict):
+                    setattr(node, key, 
self.visit_moduledict(_get_child_name(name, key), value))
+                elif isinstance(value, nn.ModuleList):
                     setattr(node, key, 
self.visit_modulelist(_get_child_name(name, key), value))
                 elif isinstance(value, nn.Module):
                     setattr(node, key, self.visit_module(_get_child_name(name, 
key), value))
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 23250f28aa..e9a4a6f624 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -715,5 +715,22 @@ def test_module_list():
     assert ["layers.0.0.weight", "layers.0.1.weight"] == 
sorted(list(named_params.keys()))
 
 
+def test_module_dict():
+    class Module(nn.Module):
+        def __init__(self):
+            self.layers = nn.ModuleDict(  # built from a plain dict
+                {"linear0": nn.Linear(4, 4, bias=False), "linear1": 
nn.Linear(4, 4, bias=False)}
+            )
+
+        def forward(self, x: nn.Tensor):
+            x = self.layers["linear0"](x)  # entries are looked up by key
+            x = self.layers["linear1"](x)
+            return x
+
+    mod = Module()
+    named_params = dict(mod.named_parameters())  # ModuleDict keys become name segments
+    assert ["layers.linear0.weight", "layers.linear1.weight"] == 
sorted(list(named_params.keys()))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_nn_mutator.py b/tests/python/relax/test_frontend_nn_mutator.py
index ffb6586159..253e24a4ed 100644
--- a/tests/python/relax/test_frontend_nn_mutator.py
+++ b/tests/python/relax/test_frontend_nn_mutator.py
@@ -65,6 +65,37 @@ def test_mutator_naming_basic():
     mutator.visit("mod3", mod3)
 
 
+def test_mutator_naming_moduledict():
+    class Module(nn.Module):
+        def __init__(self, dtype) -> None:
+            super().__init__()
+            self.param = nn.Parameter((32, 128), dtype)
+
+    class Mutator(nn.Mutator):
+        def visit_param(self, name: str, node: nn.Parameter) -> Any:
+            if node.dtype == "float64":  # each dtype marks one position
+                assert name == "mod_dict.k0.0.param"  # <root>.<key>.<index>.<attr>
+                return node
+            elif node.dtype == "float32":
+                assert name == "mod_dict.k0.1.param"
+                return node
+            elif node.dtype == "float16":
+                assert name == "mod_dict.k1.0.param"
+                return node
+            elif node.dtype == "float8":
+                assert name == "mod_dict.k1.1.param"
+                return node
+
+    mod_dict = nn.ModuleDict(
+        {
+            "k0": nn.ModuleList([Module("float64"), Module("float32")]),
+            "k1": nn.ModuleList([Module("float16"), Module("float8")]),
+        }
+    )
+    mutator = Mutator()
+    mutator.visit("mod_dict", mod_dict)  # root name prefixes every visited name
+
+
 def test_mutator_naming_modulelist():
     class Module(nn.Module):
         def __init__(self, dtype) -> None:
@@ -124,6 +155,37 @@ def test_mutator_module():
     assert isinstance(module.mod, SubModule2)
 
 
+def test_mutator_moduledict():
+    class Module1(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Module2(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Module3(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Mutator(nn.Mutator):
+        def visit_module(self, name: str, node: nn.Module) -> Any:
+            if isinstance(node, Module3):  # replace Module3 entries with Module1
+                return Module1()
+            else:
+                return node
+
+    mutator = Mutator()
+    module_dict = nn.ModuleDict({"k0": Module1(), "k1": Module2(), "k2": 
Module3()})
+    assert isinstance(module_dict["k0"], Module1)
+    assert isinstance(module_dict["k1"], Module2)
+    assert isinstance(module_dict["k2"], Module3)
+    module_dict = mutator.visit("", module_dict)  # each entry is reassigned by key
+    assert isinstance(module_dict["k0"], Module1)
+    assert isinstance(module_dict["k1"], Module2)
+    assert isinstance(module_dict["k2"], Module1)
+
+
 def test_mutator_modulelist():
     class Module1(nn.Module):
         def __init__(self) -> None:
@@ -150,7 +212,6 @@ def test_mutator_modulelist():
     assert isinstance(module_list[1], Module2)
     assert isinstance(module_list[2], Module3)
     module_list = mutator.visit("", module_list)
-    print(module_list[2])
     assert isinstance(module_list[0], Module1)
     assert isinstance(module_list[1], Module2)
     assert isinstance(module_list[2], Module1)

Reply via email to