gemini-code-assist[bot] commented on code in PR #19495:
URL: https://github.com/apache/tvm/pull/19495#discussion_r3176394649


##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -658,6 +711,40 @@ def forward(self, x):  # pylint: disable=invalid-name
         return x
 
 
+class ParameterList(Module):
+    """Holds parameters in a list."""
+
+    def __init__(self, params: list[Parameter] | None = None):
+        self.params: list[Parameter] = []
+        if params is not None:
+            self.extend(params)
+
+    def __iter__(self) -> Iterator[Parameter]:
+        return iter(self.params)
+
+    def __getitem__(self, idx: int) -> Parameter:
+        return self.params[idx]
+
+    def __setitem__(self, idx: int, param: Parameter) -> None:
+        self.params[idx] = param

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The `ParameterList` should validate that values are `nn.Parameter` instances 
in `__setitem__` to maintain type safety and satisfy the requirements of the 
added tests.
   
   ```suggestion
       def __setitem__(self, idx: int, param: Parameter) -> None:
           if not isinstance(param, Parameter):
               raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}")
           self.params[idx] = param
   ```
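   A minimal, self-contained sketch of what the suggested guard does at runtime. It does not import `tvm`, so `Parameter` below is a hypothetical stand-in for `nn.Parameter` and `ParameterListSketch` is an illustrative class, not the PR's code. Only the `isinstance` check in `__setitem__` mirrors the suggestion above.

```python
class Parameter:
    """Hypothetical stand-in for tvm.relax.frontend.nn.Parameter."""


class ParameterListSketch:
    """Illustrative list container that rejects non-Parameter values on assignment."""

    def __init__(self, params):
        # Construction-time checks are out of scope for this sketch.
        self.params = list(params)

    def __len__(self):
        return len(self.params)

    def __getitem__(self, idx):
        return self.params[idx]

    def __setitem__(self, idx, param):
        # Same guard as the suggestion: fail fast instead of storing a foreign object.
        if not isinstance(param, Parameter):
            raise TypeError(
                f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}"
            )
        self.params[idx] = param


plist = ParameterListSketch([Parameter()])
replacement = Parameter()
plist[0] = replacement  # accepted
try:
    plist[0] = 1.0  # rejected before the list is modified
except TypeError as err:
    print(err)  # ParameterList elements must be nn.Parameter, but got float
```

   Because the check runs before the assignment, a rejected value never replaces the existing entry.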



##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -625,6 +625,59 @@ def to(self, dtype: str | None = None) -> None:  # pylint: disable=invalid-name
             module.to(dtype=dtype)
 
 
+class ParameterDict(Module):
+    """Holds parameters in a dict."""
+
+    def __init__(
+        self,
+        params: OrderedDict[str, Parameter] | dict[str, Parameter] | None = None,
+    ):
+        self.params: OrderedDict[str, Parameter] = OrderedDict()
+        if params is not None:
+            self.update(params)
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self.params)
+
+    def __getitem__(self, key: str) -> Parameter:
+        return self.params[key]
+
+    def __setitem__(self, key: str, param: Parameter) -> None:
+        self.params[key] = param

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The `ParameterDict` should validate that keys are strings and values are 
`nn.Parameter` instances. This ensures consistency with the PR's stated goal of 
avoiding automatic conversions and allows the `test_type_validation` test case 
to pass as expected.
   
   ```suggestion
       def __setitem__(self, key: str, param: Parameter) -> None:
           if not isinstance(key, str):
               raise TypeError(f"ParameterDict keys must be strings, but got {type(key).__name__}")
           if not isinstance(param, Parameter):
               raise TypeError(f"ParameterDict values must be nn.Parameter, but got {type(param).__name__}")
           self.params[key] = param
   ```
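   A sketch of the two checks in isolation, again with a hypothetical `Parameter` stand-in and an illustrative `ParameterDictSketch`. The constructor in the hunk calls `self.update(params)`, but `update` itself is not shown, so the sketch assumes it delegates to `__setitem__`; under that assumption, construction from a plain `dict` is validated by the same two checks.

```python
from collections import OrderedDict


class Parameter:
    """Hypothetical stand-in for tvm.relax.frontend.nn.Parameter."""


class ParameterDictSketch:
    """Illustrative dict container whose update() routes through __setitem__."""

    def __init__(self, params=None):
        self.params = OrderedDict()
        if params is not None:
            self.update(params)

    def __getitem__(self, key):
        return self.params[key]

    def __setitem__(self, key, param):
        if not isinstance(key, str):
            raise TypeError(f"ParameterDict keys must be strings, but got {type(key).__name__}")
        if not isinstance(param, Parameter):
            raise TypeError(
                f"ParameterDict values must be nn.Parameter, but got {type(param).__name__}"
            )
        self.params[key] = param

    def update(self, params):
        # Assumption: the PR's update() (not shown) delegates to __setitem__ like this.
        for key, param in params.items():
            self[key] = param


pdict = ParameterDictSketch({"weight": Parameter()})
for bad_key, bad_value in [(0, Parameter()), ("bias", 0.5)]:
    try:
        pdict[bad_key] = bad_value
    except TypeError as err:
        print(err)
# ParameterDict keys must be strings, but got int
# ParameterDict values must be nn.Parameter, but got float
```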



##########
python/tvm/relax/frontend/nn/visitor.py:
##########
@@ -141,9 +177,19 @@ def _get_child_name(parent: str, child: str) -> str:
             else:
                 return f"{parent}.{child}"
 
-        if isinstance(node, nn.ModuleList):
+        if isinstance(node, nn.ParameterList):
+            for i in range(len(node)):
+                node[i] = self.visit_param(_get_child_name(name, str(i)), node[i])
+        elif isinstance(node, nn.ParameterDict):
+            for k, v in node.items():
+                node[k] = self.visit_param(_get_child_name(name, k), v)
+        elif isinstance(node, nn.ModuleList):
             for i in range(len(node)):
-                if isinstance(node[i], nn.ModuleDict):
+                if isinstance(node[i], nn.ParameterDict):
+                    node[i] = self.visit_parameterdict(_get_child_name(name, str(i)), node[i])
+                elif isinstance(node[i], nn.ParameterList):
+                    node[i] = self.visit_parameterlist(_get_child_name(name, str(i)), node[i])
+                elif isinstance(node[i], nn.ModuleDict):
                     node[i] = self.visit_moduledict(f"{name}.{i}", node[i])
                 elif isinstance(node[i], nn.ModuleList):
                     node[i] = self.visit_modulelist(f"{name}.{i}", node[i])

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   For consistency and to correctly handle cases where `name` might be an empty 
string, use the `_get_child_name` helper function here, as is done for the new 
`ParameterDict` and `ParameterList` cases.
   
   ```suggestion
                       node[i] = self.visit_moduledict(_get_child_name(name, str(i)), node[i])
                   elif isinstance(node[i], nn.ModuleList):
                       node[i] = self.visit_modulelist(_get_child_name(name, str(i)), node[i])
   ```
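   To make the empty-`name` case concrete: only the final `else` branch of `_get_child_name` is visible in the hunk (`return f"{parent}.{child}"`). The sketch below assumes the hidden branch returns `child` unchanged when `parent` is empty; `get_child_name` is a hypothetical re-implementation for illustration, not the visitor's actual code.

```python
def get_child_name(parent: str, child: str) -> str:
    """Assumed behavior of the visitor's _get_child_name helper."""
    if parent == "":
        # Top-level container: no leading dot.
        return child
    return f"{parent}.{child}"


for name in ["", "layers"]:
    hard_coded = f"{name}.{0}"  # what the unchanged ModuleDict/ModuleList branches build
    helper = get_child_name(name, str(0))
    print(repr(hard_coded), repr(helper))
# '.0' '0'
# 'layers.0' 'layers.0'
```

   The two spellings only diverge when `name` is empty, which is why the hard-coded `f"{name}.{i}"` can produce a leading dot for a top-level list.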



##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -658,6 +711,40 @@ def forward(self, x):  # pylint: disable=invalid-name
         return x
 
 
+class ParameterList(Module):
+    """Holds parameters in a list."""
+
+    def __init__(self, params: list[Parameter] | None = None):
+        self.params: list[Parameter] = []
+        if params is not None:
+            self.extend(params)
+
+    def __iter__(self) -> Iterator[Parameter]:
+        return iter(self.params)
+
+    def __getitem__(self, idx: int) -> Parameter:
+        return self.params[idx]
+
+    def __setitem__(self, idx: int, param: Parameter) -> None:
+        self.params[idx] = param
+
+    def __len__(self) -> int:
+        return len(self.params)
+
+    def append(self, param: Parameter) -> None:
+        """Add a parameter to the end of the ParameterList"""
+        self.params.append(param)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The `append` method should also include type validation for the `param` 
argument.
   
   ```suggestion
       def append(self, param: Parameter) -> None:
           """Add a parameter to the end of the ParameterList"""
           if not isinstance(param, Parameter):
               raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}")
           self.params.append(param)
   ```
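   Since `__init__` forwards its argument to `self.extend(params)`, construction is only type-checked if `extend` validates as well. One way to keep `__setitem__`, `append`, and `extend` consistent is a single private check. This is a sketch under assumptions: `_check_param` and `ParameterListSketch` are hypothetical names, `Parameter` is a stand-in, and the PR's actual `extend` is not shown in this hunk.

```python
from typing import Iterable, List, Optional


class Parameter:
    """Hypothetical stand-in for tvm.relax.frontend.nn.Parameter."""


def _check_param(param: object) -> None:
    # Single source of truth for the element type check and its message.
    if not isinstance(param, Parameter):
        raise TypeError(
            f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}"
        )


class ParameterListSketch:
    def __init__(self, params: Optional[Iterable[Parameter]] = None):
        self.params: List[Parameter] = []
        if params is not None:
            self.extend(params)

    def __len__(self) -> int:
        return len(self.params)

    def __setitem__(self, idx: int, param: Parameter) -> None:
        _check_param(param)
        self.params[idx] = param

    def append(self, param: Parameter) -> None:
        _check_param(param)
        self.params.append(param)

    def extend(self, params: Iterable[Parameter]) -> None:
        # Validate everything first so a bad element leaves the list unchanged.
        items = list(params)
        for param in items:
            _check_param(param)
        self.params.extend(items)


plist = ParameterListSketch([Parameter(), Parameter()])
plist.append(Parameter())
try:
    ParameterListSketch([Parameter(), "not a parameter"])
except TypeError as err:
    print(err)  # ParameterList elements must be nn.Parameter, but got str
```

   Checking the whole iterable before mutating makes `extend` all-or-nothing, at the cost of materializing the input once.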



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.