This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 819b0023e4 [Relax] Support nested ModuleList in nn.Module (#16971)
819b0023e4 is described below

commit 819b0023e46dd85a5ae8ce6294e5456abaf78f3c
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue May 7 06:09:32 2024 -0700

    [Relax] Support nested ModuleList in nn.Module (#16971)
---
 python/tvm/relax/frontend/nn/core.py           | 15 +++++++++------
 tests/python/relax/test_frontend_nn_modules.py | 15 +++++++++++++++
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 4953c1c817..46e016a242 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -607,16 +607,19 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:
 
 def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]):
     """Find attributes that satisfy the condition recursively"""
+    if isinstance(root, ModuleList):
+        for i, subitem in enumerate(root):
+            yield from _attribute_finder(subitem, prefix + f"{i}.", condition_yield)
+        return
     for name, item in root.__dict__.items():
         if condition_yield(item):
             yield prefix + name, item
         elif isinstance(item, ModuleList):
-            for i, subitem in enumerate(item):
-                yield from _attribute_finder(
-                    subitem,
-                    prefix + name + f".{i}.",
-                    condition_yield,
-                )
+            yield from _attribute_finder(
+                item,
+                prefix + name + ".",
+                condition_yield,
+            )
         elif isinstance(item, Module):
             yield from _attribute_finder(
                 item,
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 5ddc105055..23250f28aa 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -700,5 +700,20 @@ def test_nn_module_list_input():
     assert_structural_equal(tvm_mod["forward"], forward)
 
 
+def test_module_list():
+    class Module(nn.Module):
+        def __init__(self):
+            self.layers = nn.ModuleList(
+                [nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(2)]) for _ in range(1)]
+            )
+
+        def forward(self, x: nn.Tensor):
+            return self.layers(x)
+
+    mod = Module()
+    named_params = dict(mod.named_parameters())
+    assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys()))
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to