gemini-code-assist[bot] commented on code in PR #19495:
URL: https://github.com/apache/tvm/pull/19495#discussion_r3176394649
##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -658,6 +711,40 @@ def forward(self, x): # pylint: disable=invalid-name
return x
+class ParameterList(Module):
+ """Holds parameters in a list."""
+
+ def __init__(self, params: list[Parameter] | None = None):
+ self.params: list[Parameter] = []
+ if params is not None:
+ self.extend(params)
+
+ def __iter__(self) -> Iterator[Parameter]:
+ return iter(self.params)
+
+ def __getitem__(self, idx: int) -> Parameter:
+ return self.params[idx]
+
+ def __setitem__(self, idx: int, param: Parameter) -> None:
+ self.params[idx] = param
Review Comment:

The `ParameterList` should validate that values are `nn.Parameter` instances
in `__setitem__` to maintain type safety and satisfy the requirements of the
added tests.
```suggestion
def __setitem__(self, idx: int, param: Parameter) -> None:
    if not isinstance(param, Parameter):
        raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}")
    self.params[idx] = param
```
##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -625,6 +625,59 @@ def to(self, dtype: str | None = None) -> None: # pylint: disable=invalid-name
module.to(dtype=dtype)
+class ParameterDict(Module):
+ """Holds parameters in a dict."""
+
+ def __init__(
+ self,
+ params: OrderedDict[str, Parameter] | dict[str, Parameter] | None = None,
+ ):
+ self.params: OrderedDict[str, Parameter] = OrderedDict()
+ if params is not None:
+ self.update(params)
+
+ def __iter__(self) -> Iterator[str]:
+ return iter(self.params)
+
+ def __getitem__(self, key: str) -> Parameter:
+ return self.params[key]
+
+ def __setitem__(self, key: str, param: Parameter) -> None:
+ self.params[key] = param
Review Comment:

The `ParameterDict` should validate that keys are strings and values are
`nn.Parameter` instances. This ensures consistency with the PR's stated goal of
avoiding automatic conversions and allows the `test_type_validation` test case
to pass as expected.
```suggestion
def __setitem__(self, key: str, param: Parameter) -> None:
    if not isinstance(key, str):
        raise TypeError(f"ParameterDict keys must be strings, but got {type(key).__name__}")
    if not isinstance(param, Parameter):
        raise TypeError(f"ParameterDict values must be nn.Parameter, but got {type(param).__name__}")
    self.params[key] = param
```
##########
python/tvm/relax/frontend/nn/visitor.py:
##########
@@ -141,9 +177,19 @@ def _get_child_name(parent: str, child: str) -> str:
else:
return f"{parent}.{child}"
- if isinstance(node, nn.ModuleList):
+ if isinstance(node, nn.ParameterList):
+ for i in range(len(node)):
+ node[i] = self.visit_param(_get_child_name(name, str(i)), node[i])
+ elif isinstance(node, nn.ParameterDict):
+ for k, v in node.items():
+ node[k] = self.visit_param(_get_child_name(name, k), v)
+ elif isinstance(node, nn.ModuleList):
for i in range(len(node)):
- if isinstance(node[i], nn.ModuleDict):
+ if isinstance(node[i], nn.ParameterDict):
+ node[i] = self.visit_parameterdict(_get_child_name(name, str(i)), node[i])
+ elif isinstance(node[i], nn.ParameterList):
+ node[i] = self.visit_parameterlist(_get_child_name(name, str(i)), node[i])
+ elif isinstance(node[i], nn.ModuleDict):
node[i] = self.visit_moduledict(f"{name}.{i}", node[i])
elif isinstance(node[i], nn.ModuleList):
node[i] = self.visit_modulelist(f"{name}.{i}", node[i])
Review Comment:

For consistency and to correctly handle cases where `name` might be an empty
string, use the `_get_child_name` helper function here, as is done for the new
`ParameterDict` and `ParameterList` cases.
```suggestion
    node[i] = self.visit_moduledict(_get_child_name(name, str(i)), node[i])
elif isinstance(node[i], nn.ModuleList):
    node[i] = self.visit_modulelist(_get_child_name(name, str(i)), node[i])
```
##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -658,6 +711,40 @@ def forward(self, x): # pylint: disable=invalid-name
return x
+class ParameterList(Module):
+ """Holds parameters in a list."""
+
+ def __init__(self, params: list[Parameter] | None = None):
+ self.params: list[Parameter] = []
+ if params is not None:
+ self.extend(params)
+
+ def __iter__(self) -> Iterator[Parameter]:
+ return iter(self.params)
+
+ def __getitem__(self, idx: int) -> Parameter:
+ return self.params[idx]
+
+ def __setitem__(self, idx: int, param: Parameter) -> None:
+ self.params[idx] = param
+
+ def __len__(self) -> int:
+ return len(self.params)
+
+ def append(self, param: Parameter) -> None:
+ """Add a parameter to the end of the ParameterList"""
+ self.params.append(param)
Review Comment:

The `append` method should also include type validation for the `param`
argument.
```suggestion
def append(self, param: Parameter) -> None:
    """Add a parameter to the end of the ParameterList"""
    if not isinstance(param, Parameter):
        raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}")
    self.params.append(param)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]