This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 819b0023e4 [Relax] Support nested ModuleList in nn.Module (#16971)
819b0023e4 is described below
commit 819b0023e46dd85a5ae8ce6294e5456abaf78f3c
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue May 7 06:09:32 2024 -0700
[Relax] Support nested ModuleList in nn.Module (#16971)
---
python/tvm/relax/frontend/nn/core.py | 15 +++++++++------
tests/python/relax/test_frontend_nn_modules.py | 15 +++++++++++++++
2 files changed, 24 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 4953c1c817..46e016a242 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -607,16 +607,19 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:
def _attribute_finder(root: Module, prefix: str, condition_yield:
Callable[[Any], bool]):
"""Find attributes that satisfy the condition recursively"""
+ if isinstance(root, ModuleList):
+ for i, subitem in enumerate(root):
+ yield from _attribute_finder(subitem, prefix + f"{i}.",
condition_yield)
+ return
for name, item in root.__dict__.items():
if condition_yield(item):
yield prefix + name, item
elif isinstance(item, ModuleList):
- for i, subitem in enumerate(item):
- yield from _attribute_finder(
- subitem,
- prefix + name + f".{i}.",
- condition_yield,
- )
+ yield from _attribute_finder(
+ item,
+ prefix + name + ".",
+ condition_yield,
+ )
elif isinstance(item, Module):
yield from _attribute_finder(
item,
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 5ddc105055..23250f28aa 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -700,5 +700,20 @@ def test_nn_module_list_input():
assert_structural_equal(tvm_mod["forward"], forward)
+def test_module_list():
+ class Module(nn.Module):
+ def __init__(self):
+ self.layers = nn.ModuleList(
+ [nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in
range(2)]) for _ in range(1)]
+ )
+
+ def forward(self, x: nn.Tensor):
+ return self.layers(x)
+
+ mod = Module()
+ named_params = dict(mod.named_parameters())
+ assert ["layers.0.0.weight", "layers.0.1.weight"] ==
sorted(list(named_params.keys()))
+
+
if __name__ == "__main__":
tvm.testing.main()